diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 0897d7f37c..da5fa39335 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -3,41 +3,43 @@ # Analysis period: 180 days # Minimum commits threshold: 1 -benchmarks/ @bkryu @cyx-6 @nv-yunzheq @kahyunnam @jiahanc -benchmarks/routines/ @bkryu @nv-yunzheq @cyx-6 @nvmbreughe @Anerudhan +benchmarks/ @bkryu @jiahanc @cyx-6 @kahyunnam @yzh119 +benchmarks/routines/ @bkryu @nv-yunzheq @jiahanc @cyx-6 @nvmbreughe ci/ @cyx-6 @yzh119 @nvmbreughe ci/scripts/ @cyx-6 ci/scripts/jenkins/ @cyx-6 -csrc/ @wenscarl @yzh119 @cyx-6 @djmmoss @yongwww -csrc/fused_moe/ @yzh119 @yongwww @djmmoss @wenscarl @cyx-6 -csrc/fused_moe/cutlass_backend/ @yzh119 @yongwww @djmmoss @wenscarl @cyx-6 -csrc/nv_internal/ @wenscarl @djmmoss @yzh119 @cyx-6 @yongwww -csrc/nv_internal/cpp/ @wenscarl @yongwww @djmmoss @joker-eph @ttyio -csrc/nv_internal/include/ @wenscarl -csrc/nv_internal/tensorrt_llm/ @wenscarl @djmmoss @yzh119 @cyx-6 @yongwww -csrc/xqa/ @yzh119 @cyx-6 -docs/ @yzh119 @cyx-6 @wenscarl @nv-yunzheq @aleozlx -flashinfer/ @yzh119 @cyx-6 @wenscarl @nvmbreughe @bkryu +csrc/ @yzh119 @wenscarl @djmmoss @cyx-6 @nv-yunzheq +csrc/fused_moe/ @yzh119 @nv-yunzheq @djmmoss @wenscarl @yongwww +csrc/fused_moe/cutlass_backend/ @yzh119 @nv-yunzheq @djmmoss @wenscarl @yongwww +csrc/nv_internal/ @wenscarl @djmmoss @yzh119 @nv-yunzheq @yongwww +csrc/nv_internal/cpp/ @wenscarl @bkryu @yongwww @djmmoss @joker-eph +csrc/nv_internal/include/ @wenscarl @nv-yunzheq +csrc/nv_internal/tensorrt_llm/ @wenscarl @djmmoss @yzh119 @nv-yunzheq @yongwww +csrc/xqa/ @cyx-6 @yzh119 +docs/ @yzh119 @cyx-6 @bkryu @wenscarl @nv-yunzheq +flashinfer/ @yzh119 @cyx-6 @wenscarl @nvmbreughe @aleozlx flashinfer-cubin/ @yzh119 @cyx-6 flashinfer-cubin/flashinfer_cubin/ @yzh119 flashinfer-jit-cache/ @yzh119 @cyx-6 flashinfer-jit-cache/flashinfer_jit_cache/ @yzh119 flashinfer/comm/ @yzh119 @cyx-6 @nvmbreughe @wenscarl @djmmoss -flashinfer/cudnn/ @Anerudhan @yzh119 @cyx-6 @Anerudhan +flashinfer/cudnn/ @Anerudhan @yzh119 @bkryu @cyx-6 @Anerudhan flashinfer/cute_dsl/ @yzh119 @kaixih @Amir-19 @aleozlx -flashinfer/fused_moe/ @djmmoss @yzh119 @cyx-6 @wenscarl @IwakuraRein -flashinfer/jit/ @yzh119 @cyx-6 @djmmoss @aleozlx @yongwww -flashinfer/jit/attention/ @yzh119 @Anerudhan @joker-eph -flashinfer/jit/gemm/ @yzh119 +flashinfer/dsv3_ops/ @nv-yunzheq @nvmbreughe +flashinfer/fused_moe/ @yzh119 @nv-yunzheq @jiahanc @djmmoss @cyx-6 +flashinfer/gemm/ @nvmbreughe @bkryu +flashinfer/jit/ @yzh119 @cyx-6 @aleozlx @nv-yunzheq @jiahanc +flashinfer/jit/attention/ @yzh119 @cyx-6 @Anerudhan +flashinfer/jit/gemm/ @yzh119 @nv-yunzheq @jiahanc flashinfer/logits_processor/ @cyx-6 @yzh119 flashinfer/profiler/ @cyx-6 -flashinfer/triton/ @cyx-6 @nvmbreughe @yzh119 +flashinfer/triton/ @nvmbreughe @cyx-6 flashinfer/tuning_configs/ @kaixih -include/ @yzh119 @cyx-6 @wenscarl @kahyunnam @joker-eph -include/flashinfer/ @yzh119 @cyx-6 @wenscarl @kahyunnam @joker-eph +include/ @yzh119 @kahyunnam @jiahanc @IwakuraRein @nv-yunzheq +include/flashinfer/ @yzh119 @kahyunnam @jiahanc @IwakuraRein @nv-yunzheq include/flashinfer/attention/ @yzh119 @kahyunnam @joker-eph include/flashinfer/comm/ @yongwww @nvmbreughe @djmmoss @yzh119 @cyx-6 -include/flashinfer/gemm/ @ttyio @yongwww @aleozlx @cyx-6 -include/flashinfer/trtllm/ @joker-eph @aleozlx @yzh119 @cyx-6 @wenscarl +include/flashinfer/gemm/ @ttyio @yongwww @yzh119 @nvmbreughe @aleozlx +include/flashinfer/trtllm/ @jiahanc @joker-eph @aleozlx @yzh119 @IwakuraRein profiler/ @cyx-6 -scripts/ @yzh119 
@nvmbreughe @dierksen @yongwww @bkryu +scripts/ @yzh119 @nvmbreughe @kahyunnam @dierksen @yongwww diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index 4d5acdfe63..7c57d4bd7a 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -145,7 +145,7 @@ jobs: - name: Build wheel in container env: DOCKER_IMAGE: ${{ matrix.arch == 'aarch64' && format('pytorch/manylinuxaarch64-builder:cuda{0}', matrix.cuda) || format('pytorch/manylinux2_28-builder:cuda{0}', matrix.cuda) }} - FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda == '12.8' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 12.0a' }} + FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda < '13.0' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0a 12.0f' }} FLASHINFER_DEV_RELEASE_SUFFIX: ${{ needs.setup.outputs.dev_suffix }} run: | # Extract CUDA major and minor versions diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7e406ff2ac..b11e72e1f7 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -182,7 +182,7 @@ jobs: - name: Build wheel in container env: DOCKER_IMAGE: ${{ matrix.arch == 'aarch64' && format('pytorch/manylinuxaarch64-builder:cuda{0}', matrix.cuda) || format('pytorch/manylinux2_28-builder:cuda{0}', matrix.cuda) }} - FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda == '12.8' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 12.0a' }} + FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda < '13.0' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0a 12.0f' }} run: | # Extract CUDA major and minor versions CUDA_MAJOR=$(echo "${{ matrix.cuda }}" | cut -d'.' -f1) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml old mode 100644 new mode 100755 diff --git a/README.md b/README.md index 8f93c97f7a..b620c1481d 100644 --- a/README.md +++ b/README.md @@ -15,12 +15,12 @@ Kernel Library for LLM Serving [![Build Status](https://ci.tlcpack.ai/job/flashinfer-ci/job/main/badge/icon)](https://ci.tlcpack.ai/job/flashinfer-ci/job/main/) [![Documentation](https://github.com/flashinfer-ai/flashinfer/actions/workflows/build-doc.yml/badge.svg)](https://github.com/flashinfer-ai/flashinfer/actions/workflows/build-doc.yml) - FlashInfer is a library and kernel generator for Large Language Models that provides high-performance implementation of LLM GPU kernels such as FlashAttention, SparseAttention, PageAttention, Sampling, and more. FlashInfer focuses on LLM serving and inference, and delivers state-of-the-art performance across diverse scenarios. Check our [v0.2 release blog](https://flashinfer.ai/2024/12/16/flashinfer-v02-release.html) for new features! The core features of FlashInfer include: + 1. **Efficient Sparse/Dense Attention Kernels**: Efficient single/batch attention for sparse(paged)/dense KV-storage on CUDA Cores and Tensor Cores (both FA2 & FA3) templates. The vector-sparse attention can achieve 90% of the bandwidth of dense kernels with same problem size. 2. **Load-Balanced Scheduling**: FlashInfer decouples `plan`/`run` stage of attention computation where we schedule the computation of variable-length inputs in `plan` stage to alleviate load-imbalance issue. 3. 
**Memory Efficiency**: FlashInfer offers [Cascade Attention](https://docs.flashinfer.ai/api/cascade.html#flashinfer.cascade.MultiLevelCascadeAttentionWrapper) for hierarchical KV-Cache, and implements Head-Query fusion for accelerating Grouped-Query Attention, and efficient kernels for low-precision attention and fused-RoPE attention for compressed KV-Cache. @@ -31,6 +31,7 @@ The core features of FlashInfer include: FlashInfer supports PyTorch, TVM and C++ (header-only) APIs, and can be easily integrated into existing projects. ## News + - [Mar 10, 2025] [Blog Post](https://flashinfer.ai/2025/03/10/sampling.html) Sorting-Free GPU Kernels for LLM Sampling, which explains the design of sampling kernels in FlashInfer. - [Mar 1, 2025] Checkout flashinfer's [intra-kernel profiler](https://github.com/flashinfer-ai/flashinfer/tree/main/profiler) for visualizing the timeline of each threadblock in GPU kernels. - [Dec 16, 2024] [Blog Post](https://flashinfer.ai/2024/12/16/flashinfer-v02-release.html) FlashInfer 0.2 - Efficient and Customizable Kernels for LLM Inference Serving @@ -51,11 +52,13 @@ pip install flashinfer-python ``` **Package Options:** + - **flashinfer-python**: Core package that compiles/downloads kernels on first use - **flashinfer-cubin**: Pre-compiled kernel binaries for all supported GPU architectures - **flashinfer-jit-cache**: Pre-built kernel cache for specific CUDA versions **For faster initialization and offline usage**, install the optional packages to have most kernels pre-compiled: + ```bash pip install flashinfer-python flashinfer-cubin # JIT cache package (replace cu129 with your CUDA version: cu128, cu129, or cu130) @@ -75,6 +78,7 @@ python -m pip install -v . ``` **For development**, install in editable mode: + ```bash python -m pip install --no-build-isolation -e . -v ``` @@ -82,6 +86,7 @@ python -m pip install --no-build-isolation -e . -v **Build optional packages:** `flashinfer-cubin`: + ```bash cd flashinfer-cubin python -m build --no-isolation --wheel @@ -89,8 +94,9 @@ python -m pip install dist/*.whl ``` `flashinfer-jit-cache` (customize `FLASHINFER_CUDA_ARCH_LIST` for your target GPUs): + ```bash -export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 12.0a" +export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0a 12.0f" cd flashinfer-jit-cache python -m build --no-isolation --wheel python -m pip install dist/*.whl @@ -120,6 +126,7 @@ flashinfer show-config ``` This command displays: + - FlashInfer version and installed packages (flashinfer-python, flashinfer-cubin, flashinfer-jit-cache) - PyTorch and CUDA version information - Environment variables and artifact paths @@ -162,17 +169,36 @@ o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=False) # prefill att Check out [documentation](https://docs.flashinfer.ai/) for usage of batch decode/append/prefill kernels and shared-prefix cascading kernels. +## API Logging + +FlashInfer provides comprehensive API logging for debugging. Enable it using environment variables: + +```bash +# Enable logging (levels: 0=off (default), 1=basic, 3=detailed, 5=statistics) +export FLASHINFER_LOGLEVEL=3 + +# Set log destination (stdout (default), stderr, or file path) +export FLASHINFER_LOGDEST=stdout +``` + +For detailed information about logging levels, configuration, and advanced features, see [Logging](https://docs.flashinfer.ai/logging.html) in our documentation. + ## Custom Attention Variants Starting from FlashInfer v0.2, users can customize their own attention variants with additional parameters. 
For more details, refer to our [JIT examples](https://github.com/flashinfer-ai/flashinfer/blob/main/tests/utils/test_jit_example.py). -## GPU Support +## GPU and CUDA Support FlashInfer currently provides support for NVIDIA SM architectures 75 and higher and beta support for 103, 110, 120, and 121. +**Supported CUDA Versions:** 12.6, 12.8, 13.0, 13.1 + +> **Note:** FlashInfer strives to follow PyTorch's supported CUDA versions plus the latest CUDA release. + ## Adoption We are thrilled to share that FlashInfer is being adopted by many cutting-edge projects, including but not limited to: + - [MLC-LLM](https://github.com/mlc-ai/mlc-llm) - [Punica](https://github.com/punica-ai/punica) - [SGLang](https://github.com/sgl-project/sglang) diff --git a/benchmarks/README.md b/benchmarks/README.md index f41d695cdc..d81e9c3642 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -117,7 +117,7 @@ The output CSV will contain detailed metrics including: | `--verbose`, `-v` | Print additional information (can be used multiple times for more verbosity, e.g. `-vv`) | | `--case_tag` | Optional tag for the test case, useful for annotating or filtering results in the output CSV. | | `--generate_repro_command`| If set, prints a reproducer command for the test case and stores it in the output CSV. | -| `--backends` | Space-separated list of backends to test, e.g. fa2, fa2_tc, fa3, cudnn, cutlass, trtllm, trtllm-gen, trtllm-gen-native, cublas| +| `--backends` | Space-separated list of backends to test, e.g. fa2, fa2_tc, fa3, cudnn, cutlass, trtllm, trtllm-gen, trtllm-native, cublas| ### Attention Flags | Flag | Description | @@ -166,8 +166,7 @@ The output CSV will contain detailed metrics including: | `--topk_group` | Number of groups to consider for top-k routing. Default: 1 | | `--routed_scaling_factor`| Scaling factor for routing. Default: 2.5 | | `--local_expert_offset` | Offset of local experts in global expert space. Default: 0 | -| `--local_num_experts` | Number of experts handled by this device. Default: equals num_experts | -| `--tile_tokens_dim` | Tile dimension for tokens. Default: 8 | +| `--local_num_experts` | Number of experts handled by this device. Default: equals num_experts | | | `--routing_method` | Routing method: `renormalize`, `deepseek_v3`, `llama4`, `renormalize_naive`. Default: `deepseek_v3`. | | `--use_shuffled_weight` | Whether to use shuffled weight layout | | `--weight_layout` | Weight layout: 0=MajorK, 1=MajorMn, 2=BlockMajorK. 
Default: 0 | @@ -213,14 +212,14 @@ Legend: - cutlass: CUTLASS - trtllm: TensorRT-LLM - trtllm-gen: TensorRT-LLM (generic wrapper) -- trtllm-gen-native: TensorRT-LLM (native API) +- trtllm-native: TensorRT-LLM (native API) --> | Routine | 7.5 | 8.0 | 8.6 | 8.9 | 9.0 | 10.0 | 10.3 | 12.0 | |---------|-----|-----|-----|-----|-----|-------|-------|-------| -| **BatchDecodeWithPagedKVCacheWrapper** | fa2 | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-gen-native | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-gen-native | fa2, fa2_tc, cudnn | -| **BatchPrefillWithPagedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, trtllm-gen, trtllm-gen-native | fa2, cudnn, trtllm-gen, trtllm-gen-native | fa2, cudnn | -| **BatchPrefillWithRaggedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, cutlass, trtllm-gen-native | fa2, cudnn, cutlass, trtllm-gen-native | fa2, cudnn | -| **BatchMLAPagedAttentionWrapper** | | fa2 | fa2 | fa2 | fa2, fa3 | fa2, cutlass, trtllm-gen-native | fa2, cutlass, trtllm-gen-native | fa2 | +| **BatchDecodeWithPagedKVCacheWrapper** | fa2 | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native | fa2, fa2_tc, cudnn | +| **BatchPrefillWithPagedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, trtllm-gen, trtllm-native | fa2, cudnn, trtllm-gen, trtllm-native | fa2, cudnn | +| **BatchPrefillWithRaggedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, cutlass, trtllm-native | fa2, cudnn, cutlass, trtllm-native | fa2, cudnn | +| **BatchMLAPagedAttentionWrapper** | | fa2 | fa2 | fa2 | fa2, fa3 | fa2, cutlass, trtllm-native | fa2, cutlass, trtllm-native | fa2 | | **gemm_fp8_nt_groupwise** | | | | | | cutlass | cutlass | | | **group_gemm_fp8_nt_groupwise** | | | | | | cutlass | cutlass | | | **bmm_fp8** | | | | cudnn, cublas | cudnn, cublas | cudnn, cublas, cutlass | cudnn, cublas, cutlass | cudnn, cublas | @@ -238,4 +237,4 @@ Backend Legend: - cutlass: CUTLASS - trtllm: TensorRT-LLM - trtllm-gen: TensorRT-LLM -- trtllm-gen-native: TensorRT-LLM (out-of-wrapper) +- trtllm-native: TensorRT-LLM (out-of-wrapper) diff --git a/benchmarks/bench_batch_attention.py b/benchmarks/bench_batch_attention.py index 2c1071d808..c94a86eacc 100644 --- a/benchmarks/bench_batch_attention.py +++ b/benchmarks/bench_batch_attention.py @@ -436,7 +436,7 @@ def main(args: argparse.Namespace) -> None: records_new = [] records_separate = [] for cfg_id, (decode_case, prefill_case) in enumerate( - zip(decode_lens, prefill_lens), start=1 + zip(decode_lens, prefill_lens, strict=True), start=1 ): prefill_kv_lens = [p[0] for p in prefill_case] prefill_qo_lens = [p[1] for p in prefill_case] diff --git a/benchmarks/bench_blackwell_attention.py b/benchmarks/bench_blackwell_attention.py index 52452e05a8..73b0cd0b3c 100644 --- a/benchmarks/bench_blackwell_attention.py +++ b/benchmarks/bench_blackwell_attention.py @@ -18,7 +18,10 @@ import torch import flashinfer -from flashinfer.testing.utils import bench_gpu_time +from flashinfer.testing.utils import ( + bench_gpu_time, + attention_tflops_per_sec_with_actual_seq_lens, +) def bench_fmha_blackwell( @@ -69,14 +72,17 @@ def bench_fmha_blackwell( ) ms = np.median(measurements) - def flops(ms): - if causal: - return batch_size * qkv_len * 
qkv_len * num_heads * head_dim * 2 / ms / 1e9 - else: - return batch_size * qkv_len * qkv_len * num_heads * head_dim * 4 / ms / 1e9 - + TFLOPS = attention_tflops_per_sec_with_actual_seq_lens( + torch.full((batch_size,), qkv_len), + torch.full((batch_size,), qkv_len), + head_dim, + head_dim, + num_heads, + causal, + ms, + ) print( - f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {flops(ms):.3f} TFLOPs/s" + f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {TFLOPS:.3f} TFLOPs/s" ) diff --git a/benchmarks/bench_block_sparse_attention.py b/benchmarks/bench_block_sparse_attention.py index e2a51012f5..2da2478a6f 100644 --- a/benchmarks/bench_block_sparse_attention.py +++ b/benchmarks/bench_block_sparse_attention.py @@ -18,7 +18,10 @@ import torch import flashinfer -from flashinfer.testing.utils import bench_gpu_time +from flashinfer.testing.utils import ( + bench_gpu_time, + attention_tflops_per_sec_with_actual_seq_lens, +) def bench_variable_block_sparse_attention( @@ -120,7 +123,15 @@ def bench_variable_block_sparse_attention( ) def flops(ms): - return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 + return attention_tflops_per_sec_with_actual_seq_lens( + torch.tensor([seq_len]), + torch.tensor([seq_len]), + head_dim, + head_dim, + num_qo_heads, + False, + ms, + ) print( f"bench_variable_block_sparse_attention (num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}, seq_len={seq_len}, num_blocks_row={num_blocks_row}, num_blocks_col={num_blocks_col}, block_density={block_density}), sparse fa2-template: {flops(sparse_ms_fa2):.3f} TFLOPs/s, sparse fa3-template: {flops(sparse_ms_fa3):.3f} TFLOPs/s, dense fa2-template: {flops(dense_sm80_ms):.3f} TFLOPs/s, dense fa3-template: {flops(dense_sm90_ms):.3f} TFLOPs/s" diff --git a/benchmarks/bench_hopper_attention.py b/benchmarks/bench_hopper_attention.py index 6ad2fdaa1b..c1e56e6225 100644 --- a/benchmarks/bench_hopper_attention.py +++ b/benchmarks/bench_hopper_attention.py @@ -18,7 +18,10 @@ import torch import flashinfer -from flashinfer.testing.utils import bench_gpu_time +from flashinfer.testing.utils import ( + bench_gpu_time, + attention_tflops_per_sec_with_actual_seq_lens, +) def bench_single_prefill(seq_len, num_heads, causal, head_dim): @@ -41,10 +44,15 @@ def bench_single_prefill(seq_len, num_heads, causal, head_dim): ) def flops(ms): - if causal: - return seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 - else: - return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 + return attention_tflops_per_sec_with_actual_seq_lens( + torch.tensor([seq_len]), + torch.tensor([seq_len]), + head_dim, + head_dim, + num_qo_heads, + causal, + ms, + ) print( f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s" @@ -97,14 +105,15 @@ def bench_batch_ragged_prefill(batch_size, num_heads, seq_len, causal, head_dim) ) def flops(ms): - if causal: - return ( - batch_size * seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 - ) - else: - return ( - batch_size * seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 - ) + return attention_tflops_per_sec_with_actual_seq_lens( + torch.full((batch_size,), seq_len), + torch.full((batch_size,), seq_len), + head_dim, + head_dim, + num_qo_heads, + 
causal, + ms, + ) print( f"bench_batch_ragged_prefill (batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s" @@ -176,14 +185,15 @@ def bench_batch_paged_prefill( ) def flops(ms): - if causal: - return ( - batch_size * seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 - ) - else: - return ( - batch_size * seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 - ) + return attention_tflops_per_sec_with_actual_seq_lens( + torch.full((batch_size,), seq_len), + torch.full((batch_size,), seq_len), + head_dim, + head_dim, + num_qo_heads, + causal, + ms, + ) print( f"bench_batch_paged_prefill (page_size={page_size} batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s" diff --git a/benchmarks/bench_hopper_fp8_attention.py b/benchmarks/bench_hopper_fp8_attention.py index 34d71d7f9e..75b02024d6 100644 --- a/benchmarks/bench_hopper_fp8_attention.py +++ b/benchmarks/bench_hopper_fp8_attention.py @@ -1,43 +1,88 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + import numpy as np import torch import flashinfer -from flashinfer.testing.utils import bench_gpu_time +from flashinfer.testing.utils import ( + bench_gpu_time, + attention_tflops_per_sec_with_actual_seq_lens, +) -def bench_single_prefill(seq_len, num_heads, causal, head_dim): - num_qo_heads = num_kv_heads = num_heads - q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda") - k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") - v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") - - sm80_ms, sm90_ms = ( - np.median( - bench_gpu_time( - lambda: flashinfer.single_prefill_with_kv_cache_return_lse( - q, k, v, causal=causal, backend=backend - ), - dry_run_time_ms=100, - repeat_time_ms=1000, - ) - ) - for backend in ["fa2", "fa3"] +def per_head_symmetric_quant(x, quant_dtype): + """Per-head symmetric quantization to FP8.""" + o_min_val, o_max_val = ( + (-448.0, 448.0) if quant_dtype == torch.float8_e4m3fn else (-57344, 57344) ) + x_max_val = x.abs().amax(dim=(0, 2)).to(dtype=torch.float32) + s_out = torch.clamp(x_max_val / o_max_val, min=1e-6) + s_out_broadcast = s_out.view(1, -1, 1) + q_x_out = torch.clamp(x / s_out_broadcast, min=o_min_val, max=o_max_val).to( + dtype=quant_dtype + ) + return q_x_out, s_out + + +def bench_fp8_single_prefill( + seq_len, num_heads, causal, head_dim, dtype=torch.float8_e4m3fn +): + """Benchmark FP8 single prefill attention.""" + num_qo_heads = num_kv_heads = num_heads - q = torch.randn( + # Create FP16 tensors first, then quantize + q_fp16 = torch.randn( seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" - ).to(dtype=torch.float8_e4m3fn) - k = torch.randn( + ) + k_fp16 = torch.randn( seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" - ).to(dtype=torch.float8_e4m3fn) - v = torch.randn( + ) + v_fp16 = torch.randn( seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" - ).to(dtype=torch.float8_e4m3fn) + ) + + # Quantize to FP8 + q_fp8, s_q = per_head_symmetric_quant(q_fp16, dtype) + k_fp8, s_k = per_head_symmetric_quant(k_fp16, dtype) + v_fp8, s_v = per_head_symmetric_quant(v_fp16, dtype) + + # FP16 baseline (fa3) + fp16_ms = np.median( + bench_gpu_time( + lambda: flashinfer.single_prefill_with_kv_cache_return_lse( + q_fp16, k_fp16, v_fp16, causal=causal, backend="fa3" + ), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ) - fp8_sm90_ms = np.median( + # FP8 (fa3) + fp8_ms = np.median( bench_gpu_time( lambda: flashinfer.single_prefill_with_kv_cache_return_lse( - q, k, v, causal=causal, backend="fa3", o_dtype=torch.half + q_fp8, + k_fp8, + v_fp8, + causal=causal, + backend="fa3", + scale_q=s_q, + scale_k=s_k, + scale_v=s_v, ), dry_run_time_ms=100, repeat_time_ms=1000, @@ -45,13 +90,233 @@ def bench_single_prefill(seq_len, num_heads, causal, head_dim): ) def flops(ms): - if causal: - return seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 - else: - return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 + return attention_tflops_per_sec_with_actual_seq_lens( + torch.tensor([seq_len]), + torch.tensor([seq_len]), + head_dim, + head_dim, + num_qo_heads, + causal, + ms, + ) print( - f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s, fa3-fp8: {flops(fp8_sm90_ms):.3f} TFLOPs/s" + f"bench_fp8_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, 
head_dim={head_dim}), " + f"fp16: {flops(fp16_ms):.3f} TFLOPs/s ({fp16_ms:.3f}ms), " + f"fp8: {flops(fp8_ms):.3f} TFLOPs/s ({fp8_ms:.3f}ms), " + f"speedup: {fp16_ms / fp8_ms:.2f}x" + ) + + +def bench_fp8_batch_ragged_prefill( + batch_size, num_heads, seq_len, causal, head_dim, dtype=torch.float8_e4m3fn +): + """Benchmark FP8 batch ragged prefill attention.""" + num_qo_heads = num_kv_heads = num_heads + total_len = batch_size * seq_len + + # Create FP16 tensors first + q_fp16 = torch.randn( + total_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" + ) + k_fp16 = torch.randn( + total_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + v_fp16 = torch.randn( + total_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + + # Quantize to FP8 + q_fp8, s_q = per_head_symmetric_quant(q_fp16, dtype) + k_fp8, s_k = per_head_symmetric_quant(k_fp16, dtype) + v_fp8, s_v = per_head_symmetric_quant(v_fp16, dtype) + + qo_indptr = torch.arange( + 0, total_len + 1, seq_len, dtype=torch.int32, device="cuda" + ) + kv_indptr = torch.arange( + 0, total_len + 1, seq_len, dtype=torch.int32, device="cuda" + ) + + # FP16 wrapper + fp16_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"), + kv_layout="NHD", + backend="fa3", + ) + fp16_wrapper.plan( + qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim, causal=causal + ) + + # FP8 wrapper + fp8_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"), + kv_layout="NHD", + backend="fa3", + ) + fp8_wrapper.plan( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + q_data_type=dtype, + kv_data_type=dtype, + o_data_type=torch.half, + causal=causal, + ) + + fp16_ms = np.median( + bench_gpu_time( + lambda: fp16_wrapper.run(q_fp16, k_fp16, v_fp16), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ) + + fp8_ms = np.median( + bench_gpu_time( + lambda: fp8_wrapper.run(q_fp8, k_fp8, v_fp8, s_q, s_k, s_v), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ) + + def flops(ms): + return attention_tflops_per_sec_with_actual_seq_lens( + torch.full((batch_size,), seq_len), + torch.full((batch_size,), seq_len), + head_dim, + head_dim, + num_qo_heads, + causal, + ms, + ) + + print( + f"bench_fp8_batch_ragged_prefill (batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), " + f"fp16: {flops(fp16_ms):.3f} TFLOPs/s ({fp16_ms:.3f}ms), " + f"fp8: {flops(fp8_ms):.3f} TFLOPs/s ({fp8_ms:.3f}ms), " + f"speedup: {fp16_ms / fp8_ms:.2f}x" + ) + + +def bench_fp8_batch_paged_prefill( + page_size, + batch_size, + num_heads, + seq_len, + causal, + head_dim, + dtype=torch.float8_e4m3fn, +): + """Benchmark FP8 batch paged prefill attention.""" + num_qo_heads = num_kv_heads = num_heads + total_qo_len = batch_size * seq_len + num_pages = batch_size * seq_len // page_size + + # Create FP16 tensors first + q_fp16 = torch.randn( + total_qo_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" + ) + # Paged KV cache: (num_pages, page_size, num_heads, head_dim) + k_fp16 = torch.randn( + num_pages, page_size, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + v_fp16 = torch.randn( + num_pages, page_size, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + + # Quantize to FP8 + q_fp8, s_q = per_head_symmetric_quant(q_fp16, dtype) + # For paged KV, reshape to (total_tokens, num_heads, head_dim) for quantization + k_flat = 
k_fp16.view(-1, num_kv_heads, head_dim) + v_flat = v_fp16.view(-1, num_kv_heads, head_dim) + k_fp8_flat, s_k = per_head_symmetric_quant(k_flat, dtype) + v_fp8_flat, s_v = per_head_symmetric_quant(v_flat, dtype) + k_fp8 = k_fp8_flat.view(num_pages, page_size, num_kv_heads, head_dim) + v_fp8 = v_fp8_flat.view(num_pages, page_size, num_kv_heads, head_dim) + + qo_indptr = torch.arange( + 0, total_qo_len + 1, seq_len, dtype=torch.int32, device="cuda" + ) + kv_indptr = torch.arange( + 0, num_pages + 1, seq_len // page_size, dtype=torch.int32, device="cuda" + ) + kv_indices = torch.arange(0, num_pages, dtype=torch.int32, device="cuda") + last_page_len = torch.ones(batch_size, dtype=torch.int32, device="cuda") * page_size + + # FP16 wrapper + fp16_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"), + kv_layout="NHD", + backend="fa3", + ) + fp16_wrapper.plan( + qo_indptr, + kv_indptr, + kv_indices, + last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal=causal, + ) + + # FP8 wrapper + fp8_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"), + kv_layout="NHD", + backend="fa3", + ) + fp8_wrapper.plan( + qo_indptr, + kv_indptr, + kv_indices, + last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + q_data_type=dtype, + kv_data_type=dtype, + o_data_type=torch.half, + causal=causal, + ) + + fp16_ms = np.median( + bench_gpu_time( + lambda: fp16_wrapper.run(q_fp16, (k_fp16, v_fp16)), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ) + + fp8_ms = np.median( + bench_gpu_time( + lambda: fp8_wrapper.run(q_fp8, (k_fp8, v_fp8), s_q, s_k, s_v), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ) + + def flops(ms): + return attention_tflops_per_sec_with_actual_seq_lens( + torch.full((batch_size,), seq_len), + torch.full((batch_size,), seq_len), + head_dim, + head_dim, + num_qo_heads, + causal, + ms, + ) + + print( + f"bench_fp8_batch_paged_prefill (page_size={page_size}, batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), " + f"fp16: {flops(fp16_ms):.3f} TFLOPs/s ({fp16_ms:.3f}ms), " + f"fp8: {flops(fp8_ms):.3f} TFLOPs/s ({fp8_ms:.3f}ms), " + f"speedup: {fp16_ms / fp8_ms:.2f}x" ) @@ -62,8 +327,30 @@ def flops(ms): print("Current benchmark targets capability (9, 0). 
Returning...") exit() - for seq_len in [4096, 8192, 16384]: - for num_heads in [24, 32]: - for causal in [True, False]: - for head_dim in [64, 128, 256]: - bench_single_prefill(seq_len, num_heads, causal, head_dim) + # Skip single prefill for now due to compilation issues + # print("=" * 80) + # print("FP8 Single Prefill Benchmarks") + # print("=" * 80) + # for head_dim in [128, 256]: + # for seq_len in [1024, 4096, 8192]: + # bench_fp8_single_prefill(seq_len, 32, True, head_dim) + + print() + print("=" * 80) + print("FP8 Batch Ragged Prefill Benchmarks") + print("=" * 80) + for head_dim in [128, 256]: + bench_fp8_batch_ragged_prefill(128, 32, 1024, True, head_dim) + bench_fp8_batch_ragged_prefill(64, 32, 2048, True, head_dim) + bench_fp8_batch_ragged_prefill(32, 32, 4096, True, head_dim) + bench_fp8_batch_ragged_prefill(16, 32, 8192, True, head_dim) + + print() + print("=" * 80) + print("FP8 Batch Paged Prefill Benchmarks") + print("=" * 80) + for head_dim in [128, 256]: + bench_fp8_batch_paged_prefill(16, 128, 32, 1024, True, head_dim) + bench_fp8_batch_paged_prefill(16, 64, 32, 2048, True, head_dim) + bench_fp8_batch_paged_prefill(16, 32, 32, 4096, True, head_dim) + bench_fp8_batch_paged_prefill(16, 16, 32, 8192, True, head_dim) diff --git a/benchmarks/bench_logging_overhead.py b/benchmarks/bench_logging_overhead.py new file mode 100644 index 0000000000..e67edcfa45 --- /dev/null +++ b/benchmarks/bench_logging_overhead.py @@ -0,0 +1,333 @@ +#!/usr/bin/env python3 +""" +Benchmark script to measure the overhead of API logging at different levels. + +This script creates decorated and undecorated versions of a test function +(torch.matmul) and compares their performance to accurately measure logging overhead. + +Usage: + # Set the logging level before running + export FLASHINFER_LOGLEVEL=3 + python bench_logging_overhead.py + + # Or run with different levels + FLASHINFER_LOGLEVEL=0 python bench_logging_overhead.py + FLASHINFER_LOGLEVEL=1 python bench_logging_overhead.py + FLASHINFER_LOGLEVEL=3 python bench_logging_overhead.py + FLASHINFER_LOGLEVEL=5 python bench_logging_overhead.py + + # Or use the helper script to run all levels + bash benchmark_all_levels.sh +""" + +import os +import sys +import time +import torch +import numpy as np +from typing import List, Tuple + +# Get logging level BEFORE importing flashinfer +LOGGING_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL", "0")) +LOG_DEST = os.environ.get("FLASHINFER_LOGDEST", "/tmp/flashinfer_benchmark_log.txt") + +# Import the decorator +from flashinfer.api_logging import flashinfer_api + + +# Create two versions of a test function: +# 1. Undecorated (baseline) +# 2. 
Decorated (with logging) +def test_matmul_undecorated(A, B): + return torch.matmul(A, B) + + +@flashinfer_api +def test_matmul_decorated(A, B): + return torch.matmul(A, B) + + +class BenchmarkResults: + """Store and display benchmark results.""" + + def __init__(self): + self.undecorated_times = [] + self.decorated_times = [] + + def set_undecorated(self, times: List[float]): + """Set benchmark results for undecorated function.""" + self.undecorated_times = times + + def set_decorated(self, times: List[float]): + """Set benchmark results for decorated function.""" + self.decorated_times = times + + def print_summary(self, logging_level: int): + """Print a summary of benchmark results.""" + print("\n" + "=" * 80) + print("BENCHMARK RESULTS") + print("=" * 80) + + undecorated_mean = np.mean(self.undecorated_times) + undecorated_std = np.std(self.undecorated_times) + + decorated_mean = np.mean(self.decorated_times) + decorated_std = np.std(self.decorated_times) + + overhead_abs = (decorated_mean - undecorated_mean) * 1000 # ms + overhead_pct = ( + ((decorated_mean - undecorated_mean) / undecorated_mean * 100) + if undecorated_mean > 0 + else 0 + ) + + print( + f"\n{'Version':<20} {'Mean (ms)':<12} {'Std (ms)':<12} {'Median (ms)':<12}" + ) + print("-" * 80) + print( + f"{'Undecorated':<20} {undecorated_mean * 1000:<12.4f} {undecorated_std * 1000:<12.4f} {np.median(self.undecorated_times) * 1000:<12.4f}" + ) + print( + f"{'Decorated':<20} {decorated_mean * 1000:<12.4f} {decorated_std * 1000:<12.4f} {np.median(self.decorated_times) * 1000:<12.4f}" + ) + + print("\n" + "=" * 80) + print("OVERHEAD ANALYSIS") + print("=" * 80) + print(f"\nLogging Level: {logging_level}") + print(f"Absolute overhead: {overhead_abs:.4f} ms") + print(f"Relative overhead: {overhead_pct:.2f}%") + + print("\n" + "=" * 80) + print("DETAILED STATISTICS") + print("=" * 80) + + print("\nUndecorated (baseline):") + print(f" Mean: {undecorated_mean * 1000:.4f} ms") + print(f" Median: {np.median(self.undecorated_times) * 1000:.4f} ms") + print(f" Std: {undecorated_std * 1000:.4f} ms") + print(f" Min: {np.min(self.undecorated_times) * 1000:.4f} ms") + print(f" Max: {np.max(self.undecorated_times) * 1000:.4f} ms") + + print("\nDecorated (with logging):") + print(f" Mean: {decorated_mean * 1000:.4f} ms") + print(f" Median: {np.median(self.decorated_times) * 1000:.4f} ms") + print(f" Std: {decorated_std * 1000:.4f} ms") + print(f" Min: {np.min(self.decorated_times) * 1000:.4f} ms") + print(f" Max: {np.max(self.decorated_times) * 1000:.4f} ms") + + +def setup_test_inputs( + batch_size: int = 32, + m: int = 512, + n: int = 512, + k: int = 512, + device: str = "cuda:0", +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Set up test inputs for matmul. + + Parameters + ---------- + batch_size : int + Batch size for the matrix multiplication + m, n, k : int + Matrix dimensions + device : str + Device to use + + Returns + ------- + A, B : torch.Tensor + Input tensors for matrix multiplication + """ + # Create random tensors + A = torch.randn(batch_size, m, k, dtype=torch.float16, device=device) + B = torch.randn(batch_size, k, n, dtype=torch.float16, device=device) + + return A, B + + +def warmup(func, A, B, num_warmup: int = 10): + """Warmup the GPU and JIT compilation.""" + for _ in range(num_warmup): + _ = func(A, B) + torch.cuda.synchronize() + + +def benchmark_function( + func, func_name: str, A, B, num_iterations: int = 100 +) -> List[float]: + """ + Benchmark a specific function. 
+ + Parameters + ---------- + func : callable + Function to benchmark + func_name : str + Name of the function (for display) + A, B : torch.Tensor + Input tensors for matrix multiplication + num_iterations : int + Number of iterations to run + + Returns + ------- + List[float] + List of execution times in seconds + """ + print(f"\nBenchmarking: {func_name}") + print(f" Running {num_iterations} iterations...") + + times = [] + + for _ in range(num_iterations): + # Synchronize before timing + torch.cuda.synchronize() + + # Time the execution + start = time.perf_counter() + _ = func(A, B) + torch.cuda.synchronize() + end = time.perf_counter() + + elapsed = end - start + times.append(elapsed) + + print(f" Complete. Mean time: {np.mean(times) * 1000:.4f} ms") + + return times + + +def main(): + """Main benchmark function.""" + print("=" * 80) + print("FlashInfer API Logging Overhead Benchmark") + print("=" * 80) + + # Display logging configuration + print("\nLogging Configuration:") + print(f" FLASHINFER_LOGLEVEL = {LOGGING_LEVEL}") + print(f" FLASHINFER_LOGDEST = {LOG_DEST}") + + # Get level name + level_names = { + 0: "No logging (zero-overhead)", + 1: "Function name only", + 3: "Name + inputs/outputs + metadata", + 5: "Name + inputs/outputs + metadata + statistics", + } + print(f" Level description: {level_names.get(LOGGING_LEVEL, 'Unknown')}") + + # Check if CUDA is available + if not torch.cuda.is_available(): + print("\nError: CUDA is not available. This benchmark requires a CUDA device.") + exit(1) + + device = "cuda:0" + print(f"\nDevice: {device}") + print(f"Device Name: {torch.cuda.get_device_name(device)}") + + # Setup test inputs + print("\nSetting up test inputs...") + batch_size = 32 + m, n, k = 128, 128, 128 + print(f" Batch size: {batch_size}") + print(f" Matrix dimensions: [{batch_size}, {m}, {k}] @ [{batch_size}, {k}, {n}]") + + A, B = setup_test_inputs(batch_size, m, n, k, device) + + # Benchmark parameters + num_iterations = 100 + print("\nBenchmark parameters:") + print(f" Iterations: {num_iterations}") + print(" Warmup iterations: 10") + + # Clear log file before starting + if os.path.exists(LOG_DEST): + os.remove(LOG_DEST) + + print("\n" + "=" * 80) + print("WARMUP PHASE") + print("=" * 80) + + # Warmup undecorated version + print("\nWarming up undecorated version...") + warmup(test_matmul_undecorated, A, B, num_warmup=10) + print(" Complete.") + + # Warmup decorated version + print("\nWarming up decorated version...") + warmup(test_matmul_decorated, A, B, num_warmup=10) + print(" Complete.") + + print("\n" + "=" * 80) + print("BENCHMARK PHASE") + print("=" * 80) + + # Store results + results = BenchmarkResults() + + # Benchmark undecorated version + undecorated_times = benchmark_function( + test_matmul_undecorated, "Undecorated (baseline)", A, B, num_iterations + ) + results.set_undecorated(undecorated_times) + + # Benchmark decorated version + decorated_times = benchmark_function( + test_matmul_decorated, + f"Decorated (logging level {LOGGING_LEVEL})", + A, + B, + num_iterations, + ) + results.set_decorated(decorated_times) + + # Print summary + results.print_summary(LOGGING_LEVEL) + + # Check log file size + if LOGGING_LEVEL > 0 and os.path.exists(LOG_DEST): + log_size = os.path.getsize(LOG_DEST) + print("\n" + "=" * 80) + print("LOG FILE INFO") + print("=" * 80) + print(f"Log file: {LOG_DEST}") + print(f"Log size: {log_size / 1024:.2f} KB ({log_size} bytes)") + print(f"Iterations logged: {num_iterations}") + print(f"Bytes per iteration: {log_size / 
num_iterations:.2f}") + + # Cleanup option + cleanup_log = os.environ.get("CLEANUP_LOG", "true").lower() == "true" + if cleanup_log: + os.remove(LOG_DEST) + print("\n Log file removed (set CLEANUP_LOG=false to keep it)") + else: + print(f"\n Log file preserved at {LOG_DEST}") + + print("\n" + "=" * 80) + print("RECOMMENDATIONS") + print("=" * 80) + print("\nTo benchmark other levels, run:") + for level in [0, 1, 3, 5]: + if level != LOGGING_LEVEL: + print(f" FLASHINFER_LOGLEVEL={level} python {sys.argv[0]}") + + print("\n" + "=" * 80) + print("Benchmark complete!") + print("=" * 80) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\n\nBenchmark interrupted by user.") + except Exception as e: + print(f"\n\nError during benchmark: {e}") + import traceback + + traceback.print_exc() diff --git a/benchmarks/bench_mixed_attention.py b/benchmarks/bench_mixed_attention.py index 85753a71f9..7414a58af0 100644 --- a/benchmarks/bench_mixed_attention.py +++ b/benchmarks/bench_mixed_attention.py @@ -23,7 +23,10 @@ def run_bench( q_lens = torch.tensor(d_qo_lens + p_qo_lens, dtype=torch.int32) seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int() - d_seq_lens_blocks = ( + p_seq_lens_blocks = torch.ceil( + torch.tensor(p_kv_lens, dtype=torch.int32) / page_block_size + ).int() + d_seq_lens_blocks = torch.ceil( torch.tensor(d_kv_lens, dtype=torch.int32) / page_block_size ).int() @@ -31,6 +34,14 @@ def run_bench( kv_indptr = torch.cat( [torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0 ).int() + + p_q_indptr = torch.cat( + [torch.tensor([0]), torch.cumsum(torch.tensor(p_qo_lens), 0)], dim=0 + ).int() + p_kv_indptr = torch.cat( + [torch.tensor([0]), torch.cumsum(p_seq_lens_blocks, 0)], dim=0 + ).int() + d_q_indptr = torch.cat( [torch.tensor([0]), torch.cumsum(torch.tensor(d_qo_lens), 0)], dim=0 ).int() @@ -46,7 +57,7 @@ def run_bench( device, dtype=torch.bfloat16 ) - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) + workspace_buffer = torch.empty(156 * 1024 * 1024, dtype=torch.uint8, device=device) kv_layout = "NHD" wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper( @@ -72,7 +83,85 @@ def run_bench( measurements = bench_gpu_time(lambda: wrapper_old.run(q, kv_data)) ms_old = np.median(measurements) + wrapper_persistent = flashinfer.BatchAttention(kv_layout="NHD") + wrapper_persistent.plan( + q_indptr.to(device), + kv_indptr.to(device), + torch.arange(num_blocks, dtype=torch.int32, device=device), + seq_lens.to(device), + num_qo_heads, + num_kv_heads, + head_dim, + head_dim, + page_block_size, + causal=causal, + q_data_type=torch.bfloat16, + kv_data_type=torch.bfloat16, + ) + o_persistent, _ = wrapper_persistent.run(q, kv_data) + measurements_persistent = bench_gpu_time(lambda: wrapper_persistent.run(q, kv_data)) + ms_persistent = np.mean(measurements_persistent) + + # Batched POD Attention + q_d = q[: d_q_indptr[-1]] + kv_d = kv_data[: d_kv_indptr[-1]].unbind(1) + q_p = q[d_q_indptr[-1] :] + kv_p = kv_data[d_kv_indptr[-1] :].unbind(1) + kv_indices_d = torch.arange(0, d_kv_indptr[-1], device=device, dtype=torch.int32) + kv_indices_p = torch.arange(0, p_kv_indptr[-1], device=device, dtype=torch.int32) + + last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1 + last_page_len_p = (p_seq_lens_blocks - 1) % page_block_size + 1 + wrapper_pod = flashinfer.BatchPODWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout=kv_layout, + ) + + wrapper_pod.plan( + # Prefill params + p_q_indptr.to(device), + 
p_kv_indptr.to(device), + kv_indices_p.to(device), + last_page_len_p, + # Decode params + d_q_indptr.to(device), + d_kv_indptr.to(device), + kv_indices_d.to(device), + last_page_len_d, + # Common params + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + page_size=page_block_size, + q_data_type=torch.bfloat16, + kv_data_type=torch.bfloat16, + ) + o_p_batch, o_d_batch = wrapper_pod.run( + q_p, + kv_p, + q_d, + kv_d, + causal_p=causal, + ) + o_batch_pod = torch.cat([o_d_batch, o_p_batch], dim=0) + + # Verify output matches + torch.testing.assert_close( + o_batch_pod, o, rtol=4e-3, atol=4e-3, msg="Batch POD-Attention decode mismatch!" + ) + measurements = bench_gpu_time( + lambda: wrapper_pod.run( + q_p, + kv_p, + q_d, + kv_d, + causal_p=causal, + ) + ) + ms_batch_pod = np.median(measurements) + if len(p_kv_lens) == 1: + # Single POD attention q_d = q[: d_q_indptr[-1]] kv_d = kv_data[: d_kv_indptr[-1]].unbind(1) q_p = q[d_q_indptr[-1] :] @@ -109,7 +198,7 @@ def run_bench( o_pod = torch.cat([o_d, o_p], dim=0) # Verify output matches torch.testing.assert_close( - o, o_pod, rtol=1e-3, atol=1e-3, msg="POD-Attention output mismatch!" + o, o_pod, rtol=4e-3, atol=4e-3, msg="POD-Attention output mismatch!" ) measurements = bench_gpu_time( lambda: wrapper_pod.run( @@ -123,9 +212,51 @@ def run_bench( ) ) ms_pod = np.median(measurements) + + # Sequential two kernels: single prefill + batch decode (tensor cores) + # Prefill using single_prefill_with_kv_cache + def _run_single_prefill(): + return flashinfer.prefill.single_prefill_with_kv_cache( + q_p, + k_p, + v_p, + causal=causal, + pos_encoding_mode="NONE", + backend="fa2", + ) + + measurements_prefill = bench_gpu_time(lambda: _run_single_prefill()) + ms_prefill = np.median(measurements_prefill) + + # Batch decode using tensor cores + wrapper_decode = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, kv_layout=kv_layout, use_tensor_cores=True + ) + wrapper_decode.plan( + d_kv_indptr.to(device), + kv_indices_d.to(device), + last_page_len_d, + num_qo_heads, + num_kv_heads, + head_dim, + page_block_size, + data_type=torch.bfloat16, + q_data_type=torch.bfloat16, + ) + measurements_decode = bench_gpu_time(lambda: wrapper_decode.run(q_d, kv_d)) + ms_decode = np.median(measurements_decode) + ms_seq_two_kernels = ms_prefill + ms_decode + print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms") + print(f"Elapsed time (Batched POD Attention): {ms_batch_pod:.2f} ms") if len(p_kv_lens) == 1: print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms") + print(f"Elapsed time (Sequential two kernels): {ms_seq_two_kernels:.2f} ms") + print(f"Elapsed time (Persistent BatchAttention): {ms_persistent:.2f} ms") + print( + f"Batch POD speedup over Persistent BatchAttention: {ms_persistent / ms_batch_pod:.2f}x" + ) + total_bytes = ( q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size() ) @@ -134,9 +265,21 @@ def run_bench( bandwidth_old_gb_s = total_bytes / (ms_old * 1e-3) / (1024**3) print(f"Memory bandwidth (Batched Prefill): {bandwidth_old_gb_s:.2f} GB/s") + bandwidth_batch_pod_gb_s = total_bytes / (ms_batch_pod * 1e-3) / (1024**3) + print( + f"Memory bandwidth (Batched POD Attention): {bandwidth_batch_pod_gb_s:.2f} GB/s" + ) if len(p_kv_lens) == 1: bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3) print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s") + bandwidth_seq_gb_s = total_bytes / (ms_seq_two_kernels * 1e-3) / (1024**3) + print( + f"Memory bandwidth (Sequential two 
kernels): {bandwidth_seq_gb_s:.2f} GB/s" + ) + bandwidth_persistent_gb_s = total_bytes / (ms_persistent * 1e-3) / (1024**3) + print( + f"Memory bandwidth (Persistent BatchAttention): {bandwidth_persistent_gb_s:.2f} GB/s" + ) if __name__ == "__main__": @@ -144,74 +287,26 @@ def run_bench( torch.random.manual_seed(42) # Irregular sequence lengths for prefill and decode - d_q_len_configs = [[1] * 122, [1] * 128, [1] * 242, [1] * 256] - d_kv_len_configs = [[600] * 122, [10000] * 128, [400] * 242, [8192] * 256] - p_q_configs = [[17] * 1, [10000], [17] * 1, []] - p_kv_configs = [[10000] * 1, [10000], [8192] * 1, []] - - # construct random length testcases - for _ in range(1): - bsz = 256 - stride = 16 - sparsity = 0.05 - - full_kv_len = np.random.randint(1000, 8192, size=bsz) - p_q_lens = [] - p_kv_lens = [] - d_q_lens = [] - d_kv_lens = [] - for i in range(bsz): - if i % stride == 0: - kv_len = full_kv_len[i] - qo_len = stride + 1 - p_q_lens.append(qo_len) - p_kv_lens.append(kv_len) - else: - kv_len = int(full_kv_len[i] * sparsity) - qo_len = 1 - d_q_lens.append(qo_len) - d_kv_lens.append(kv_len) - - p_q_configs.append(p_q_lens) - p_kv_configs.append(p_kv_lens) - d_q_len_configs.append(d_q_lens) - d_kv_len_configs.append(d_kv_lens) - - for _ in range(1): - bsz = 128 - stride = 16 - sparsity = 0.05 - - full_kv_len = np.random.randint(2000, 16000, size=bsz) - p_q_lens = [] - p_kv_lens = [] - d_q_lens = [] - d_kv_lens = [] - - for i in range(bsz): - if i % stride == 0: - kv_len = full_kv_len[i] - qo_len = stride + 1 - p_q_lens.append(qo_len) - p_kv_lens.append(kv_len) - else: - kv_len = int(full_kv_len[i] * sparsity) - qo_len = 1 - d_q_lens.append(qo_len) - d_kv_lens.append(kv_len) - - p_q_configs.append(p_q_lens) - p_kv_configs.append(p_kv_lens) - d_q_len_configs.append(d_q_lens) - d_kv_len_configs.append(d_kv_lens) + d_q_len_configs = [[1] * 128] * 7 + d_kv_len_configs = [ + [2048] * 128, + [2048] * 128, + [2048] * 128, + [2048] * 128, + [4096] * 128, + [8192] * 128, + [8192] * 128, + ] + p_q_configs = [[512], [1536], [2048] * 2, [2048], [4096], [4096], [6000]] + p_kv_configs = [[512], [1536], [2048] * 2, [2048], [4096], [4096], [7000]] page_block_size = 1 - num_kv_heads = 4 - num_qo_heads = 28 + num_kv_heads = 8 + num_qo_heads = 32 head_dim = 128 for idx, (p_q_lens, p_kv_lens, d_q_len, d_kv_len) in enumerate( - zip(p_q_configs, p_kv_configs, d_q_len_configs, d_kv_len_configs) + zip(p_q_configs, p_kv_configs, d_q_len_configs, d_kv_len_configs, strict=True) ): print(f"===== Benchmark {idx + 1}: (kv_len, qo_len) set =====") run_bench( diff --git a/benchmarks/bench_mm_fp8.py b/benchmarks/bench_mm_fp8.py index a4df76ebd9..7661d5a57e 100644 --- a/benchmarks/bench_mm_fp8.py +++ b/benchmarks/bench_mm_fp8.py @@ -67,11 +67,12 @@ def bench_mm_fp8(m, n, k, in_dtype, out_dtype): input_fp8, prepared_weights, global_scale, - res, + out=res, ), - dry_run_time_ms=500, - repeat_time_ms=2500, + dry_run_time_ms=25, + repeat_time_ms=100, # 100ms should be enough for low latency kernels that run within 100 usec use_cuda_graph=True, + enable_cupti=True, ) ms = np.median(measurements) tflops_per_second = 2 * m * n * k * 1e-9 / ms diff --git a/benchmarks/bench_rope_quantize_fp8_append_cache.py b/benchmarks/bench_rope_quantize_fp8_append_cache.py new file mode 100644 index 0000000000..3119b9fef8 --- /dev/null +++ b/benchmarks/bench_rope_quantize_fp8_append_cache.py @@ -0,0 +1,342 @@ +""" +Copyright (c) 2024 by FlashInfer team. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +import sys +import argparse +import flashinfer +import numpy as np +import torch +from flashinfer.testing.utils import bench_gpu_time_with_cudagraph +from flashinfer.utils import get_gpu_memory_bandwidth + +# Add the project root to Python path to import test helpers +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from tests.test_helpers.rope_reference import RotaryEmbedding + + +def benchmark_config( + config_name, + num_tokens, + batch_size=4, + page_size=16, + enable_pdl=False, + single_run=False, +): + """Benchmark a specific attention configuration with paged KV cache append.""" + input_dtype = torch.bfloat16 + device = "cuda" + quant_dtype = torch.float8_e4m3fn + + # Configuration-specific parameters + if config_name == "mla": + # MLA: DeepSeek-style multi-latent attention + num_qo_heads, num_kv_heads = 128, 1 + rope_dim, no_rope_dim = 64, 512 + elif config_name == "gqa": + # GQA: Grouped-query attention (e.g., Llama-style) + num_qo_heads, num_kv_heads = 32, 8 + rope_dim, no_rope_dim = 64, 64 + elif config_name == "mha": + # MHA: Standard multi-head attention + num_qo_heads, num_kv_heads = 32, 32 + rope_dim, no_rope_dim = 64, 64 + else: + raise ValueError(f"Unknown config: {config_name}") + + head_dim = rope_dim + no_rope_dim + + # Create input tensors + if config_name == "mla": + # MLA: 2D K tensors (shared) + q_rope = torch.randn( + num_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope = torch.randn( + num_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + k_rope = torch.randn(num_tokens, rope_dim, dtype=input_dtype, device=device) + k_nope = torch.randn(num_tokens, no_rope_dim, dtype=input_dtype, device=device) + v = None + else: + # GQA/MHA: 3D K/V tensors + q_rope = torch.randn( + num_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope = torch.randn( + num_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + k_rope = torch.randn( + num_tokens, num_kv_heads, rope_dim, dtype=input_dtype, device=device + ) + k_nope = torch.randn( + num_tokens, num_kv_heads, no_rope_dim, dtype=input_dtype, device=device + ) + v = torch.randn( + num_tokens, num_kv_heads, head_dim, dtype=input_dtype, device=device + ) + + # Create RoPE reference for cos/sin cache (ensure it covers this run) + max_seq_len = int(num_tokens) + rope_ref = RotaryEmbedding( + head_size=head_dim, + rotary_dim=rope_dim, + max_position_embeddings=max_seq_len, + base=10000, + is_neox_style=False, + dtype=input_dtype, + device=device, + ) + pos_ids = torch.arange(num_tokens, device=device, dtype=torch.int32) + + # Build paged metadata (single request with all tokens) + kv_append_length = torch.tensor( + [num_tokens] + [0] * (batch_size - 1), dtype=torch.int32, device=device + ) + kv_append_indptr = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(kv_append_length, dim=0), + ] + ) + num_pages_per_req = torch.tensor( + [(num_tokens + page_size 
- 1) // page_size] + [0] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + kv_page_indptr = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(num_pages_per_req, dim=0), + ] + ) + kv_page_indices = torch.arange( + kv_page_indptr[-1].item(), dtype=torch.int32, device=device + ) + kv_last_page_len = torch.tensor( + [num_tokens % page_size if num_tokens % page_size != 0 else page_size] + + [0] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + + # Get batch_indices and positions + seq_lens = flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size) + batch_indices, positions = flashinfer.get_batch_indices_positions( + kv_append_indptr, seq_lens, num_tokens + ) + + # Allocate caches + max_pages = kv_page_indptr[-1].item() + + if config_name == "mla": + ckv_cache = torch.zeros( + max_pages, page_size, no_rope_dim, dtype=quant_dtype, device=device + ) + kpe_cache = torch.zeros( + max_pages, page_size, rope_dim, dtype=quant_dtype, device=device + ) + paged_kv_cache = (ckv_cache, kpe_cache) + else: + # GQA/MHA: use NHD layout + k_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + v_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + paged_kv_cache = (k_cache, v_cache) + + run_idx = 0 + + def execute(): + if single_run: + import torch.cuda.nvtx as nvtx + + nvtx.range_push("rope_append") + nonlocal run_idx + run_idx += 1 + + flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope=q_rope, + k_rope=k_rope, + q_nope=q_nope, + k_nope=k_nope, + v=v, + cos_sin_cache=rope_ref.cos_sin_cache, + pos_ids=pos_ids, + paged_kv_cache=paged_kv_cache, + kv_indices=kv_page_indices, + kv_indptr=kv_page_indptr, + batch_indices=batch_indices, + positions=positions, + page_size=page_size, + kv_layout="NHD" if config_name != "mla" else "NHD", + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + if single_run: + # Ensure kernels complete inside the NVTX range for ncu filtering + torch.cuda.synchronize() + nvtx.range_pop() + + if single_run: + execute() + return None, None, None, None, None + measurements = bench_gpu_time_with_cudagraph(execute) + + # Calculate I/O bytes + # Inputs: q_rope, k_rope, q_nope, k_nope, v (if not MLA), cos_sin_cache, pos_ids + io_bytes = ( + q_rope.numel() * q_rope.element_size() + + k_rope.numel() * k_rope.element_size() + + q_nope.numel() * q_nope.element_size() + + k_nope.numel() * k_nope.element_size() + + rope_ref.cos_sin_cache.numel() * rope_ref.cos_sin_cache.element_size() + + pos_ids.numel() * pos_ids.element_size() + ) + + if v is not None: + io_bytes += v.numel() * v.element_size() + + # Outputs: q_rope_out, q_nope_out (FP8), cache writes (FP8) + io_bytes += ( + q_rope.numel() * torch.finfo(quant_dtype).bits // 8 + + q_nope.numel() * torch.finfo(quant_dtype).bits // 8 + ) + + if config_name == "mla": + # MLA writes to ckv_cache and kpe_cache + io_bytes += ( + num_tokens * no_rope_dim * torch.finfo(quant_dtype).bits // 8 + + num_tokens * rope_dim * torch.finfo(quant_dtype).bits // 8 + ) + else: + # GQA/MHA writes to k_cache and v_cache + io_bytes += ( + num_tokens * num_kv_heads * head_dim * torch.finfo(quant_dtype).bits // 8 + + num_tokens * num_kv_heads * head_dim * torch.finfo(quant_dtype).bits // 8 + ) + + # Calculate statistics + ms = np.median(measurements) + min_ms = np.percentile(measurements, 20) + 
max_ms = np.percentile(measurements, 80) + + # Calculate bandwidth in GB/s + bandwidth_gb_s = io_bytes / ms / 1e6 + + # Calculate TFLOPs (FP operations) + # RoPE: 6 FLOPs per dimension pair (2 muls + 1 sub for real, 2 muls + 1 add for imag) + # For Q: num_tokens * num_qo_heads * (rope_dim/2) pairs * 6 FLOPs + # For K: depends on architecture + q_flops = num_tokens * num_qo_heads * (rope_dim / 2) * 6 + + if config_name == "mla": + # MLA: K is 2D (no head dimension) + k_flops = num_tokens * (rope_dim / 2) * 6 + else: + # GQA/MHA: K is 3D (has head dimension) + k_flops = num_tokens * num_kv_heads * (rope_dim / 2) * 6 + + total_flops = q_flops + k_flops + tflops = ( + total_flops / ms / 1e9 + ) # TFLOPs (operations per ms = operations per second / 1e12) + + return ms, min_ms, max_ms, bandwidth_gb_s, tflops + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--ncu-single", action="store_true", help="Run a single execute() for ncu" + ) + parser.add_argument( + "--config", type=str, default="", help="Config name: mla/gqa/mha" + ) + parser.add_argument("--num-tokens", type=int, default=0) + parser.add_argument("--page-size", type=int, default=16) + parser.add_argument("--enable-pdl", type=int, default=0) + args, unknown = parser.parse_known_args() + + if args.ncu_single: + # Minimal single-run for ncu profiling + cfg = args.config or "mla" + ntok = int(args.num_tokens) + pgsz = int(args.page_size) + en_pdl = bool(int(args.enable_pdl)) + # Force a single execution path + benchmark_config(cfg, ntok, page_size=pgsz, enable_pdl=en_pdl, single_run=True) + sys.exit(0) + + # Get GPU information (for display only) + device = torch.device("cuda:0") + gpu_name = torch.cuda.get_device_name(0) + gpu_peak_bandwidth = get_gpu_memory_bandwidth(device) + print(f"\nDetected GPU: {gpu_name}") + print(f"Theoretical Peak Memory Bandwidth: {gpu_peak_bandwidth:.2f} GB/s") + print() + + # Token counts to benchmark + token_counts = [1, 32, 128, 384, 768, 1024, 2048, 4096, 8192] + + # Helper function to print a table for a specific configuration + def print_config_table(config_name, config_desc): + page_size_to_benchmark = 32 + print(f"\n{'=' * 100}") + print(f" {config_name.upper()}: {config_desc}") + print(f"{'=' * 100}") + + print( + f"{'Tokens':<10} {'Time (ms)':<12} {'BW (GB/s)':<12} {'BW% (Peak)':<14} {'TFLOPs':<12}" + ) + print("-" * 70) + for num_tokens in token_counts: + ms, _, _, bw, tflops = benchmark_config( + config_name, num_tokens, page_size=page_size_to_benchmark + ) + bw_pct = (bw / gpu_peak_bandwidth) * 100 + print( + f"{num_tokens:<10} {ms:<12.5f} {bw:<12.2f} {bw_pct:<14.1f} {tflops:<12.3f}" + ) + + # Print tables for each configuration + print_config_table("mla", "128 Q heads, 1 K head, 64+512 dims (DeepSeek-style)") + print_config_table("gqa", "32 Q heads, 8 K heads, 64+64 dims (Llama-style)") + print_config_table("mha", "32 Q heads, 32 K heads, 64+64 dims (Standard)") + + print("\n" + "=" * 100) + print("Configuration details:") + print(" Page size: 32, Batch size: 4") + print(" Token range: 1 (single decode) → 8192 (large prefill)") + print(f" GPU: {gpu_name}") + print(f" Theoretical Peak Memory Bandwidth: {gpu_peak_bandwidth:.2f} GB/s") + print(" BW% calculated as: (achieved_bandwidth / peak_bandwidth) * 100") + print("=" * 100) diff --git a/benchmarks/bench_sampling.py b/benchmarks/bench_sampling.py index 2eb2de3875..cc2406e43f 100644 --- a/benchmarks/bench_sampling.py +++ b/benchmarks/bench_sampling.py @@ -220,6 +220,86 @@ def main(): f"vocab_size: 
{vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" ) + print("---") + print("top-p renorm probs") + for vocab_size in [128512]: + for batch_size in [1, 16, 32, 64, 128, 256, 512]: + torch.manual_seed(42) + for distrib in [ + normal_distribution(1), + normal_distribution(5), + gumbel_distribution(0.1), + gumbel_distribution(1), + ]: + for p in [0.1, 0.5, 0.9]: + logits = distrib((batch_size, vocab_size), device="cuda") + probs = torch.softmax(logits, dim=-1) + measurements = bench_gpu_time( + lambda: flashinfer.sampling.top_p_renorm_probs(probs, p), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ms = np.median(measurements) + + io = probs.numel() * probs.element_size() * 2 + bandwidth = io * 1e-6 / ms + print( + f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" + ) + + print("---") + print("top-k renorm probs") + for vocab_size in [128512]: + for batch_size in [1, 16, 32, 64, 128, 256, 512]: + torch.manual_seed(42) + for distrib in [ + normal_distribution(1), + normal_distribution(5), + gumbel_distribution(0.1), + gumbel_distribution(1), + ]: + for k in [10, 100, 1000, 5000]: + logits = distrib((batch_size, vocab_size), device="cuda") + probs = torch.softmax(logits, dim=-1) + measurements = bench_gpu_time( + lambda: flashinfer.sampling.top_k_renorm_probs(probs, k), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ms = np.median(measurements) + + io = probs.numel() * probs.element_size() * 2 + bandwidth = io * 1e-6 / ms + print( + f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" + ) + + print("---") + print("top-k mask logits") + for vocab_size in [128512]: + for batch_size in [1, 16, 32, 64, 128, 256, 512]: + torch.manual_seed(42) + for distrib in [ + normal_distribution(1), + normal_distribution(5), + gumbel_distribution(0.1), + gumbel_distribution(1), + ]: + for k in [10, 100, 1000, 5000]: + logits = distrib((batch_size, vocab_size), device="cuda") + measurements = bench_gpu_time( + lambda: flashinfer.sampling.top_k_mask_logits(logits, k), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ms = np.median(measurements) + + io = logits.numel() * logits.element_size() * 2 + bandwidth = io * 1e-6 / ms + print( + f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" + ) + if __name__ == "__main__": main() diff --git a/benchmarks/bench_softmax.py b/benchmarks/bench_softmax.py new file mode 100755 index 0000000000..6da8dc9fcb --- /dev/null +++ b/benchmarks/bench_softmax.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +""" +Benchmark script comparing torch.softmax vs flashinfer.softmax performance. +Creates a heatmap showing speedup across different batch sizes and hidden dimensions. 
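+
+Run directly (no CLI arguments; the batch and hidden sizes are configured in main()):
+
+    python3 benchmarks/bench_softmax.py
+
+Plotting requires matplotlib and seaborn in addition to torch and flashinfer.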
+""" + +import numpy as np +import torch +import matplotlib.pyplot as plt +import seaborn as sns +from typing import List, Tuple +import flashinfer +from flashinfer.testing.utils import bench_gpu_time + + +@torch.inference_mode() +def benchmark_torch_softmax(logits: torch.Tensor) -> float: + """Benchmark torch's native softmax.""" + measurements = bench_gpu_time( + lambda: torch.softmax(logits, dim=-1), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + return np.median(measurements) + + +@torch.inference_mode() +def benchmark_flashinfer_softmax(logits: torch.Tensor) -> float: + """Benchmark flashinfer's softmax.""" + measurements = bench_gpu_time( + lambda: flashinfer.sampling.softmax(logits, temperature=None, enable_pdl=False), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + return np.median(measurements) + + +def run_benchmark( + batch_sizes: List[int], hidden_sizes: List[int] +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Run benchmarks for all combinations of batch_size and hidden_size. + + Returns: + torch_times: 2D array of torch softmax times (ms) + flashinfer_times: 2D array of flashinfer softmax times (ms) + speedups: 2D array of speedup ratios (torch_time / flashinfer_time) + """ + n_batch = len(batch_sizes) + n_hidden = len(hidden_sizes) + + torch_times = np.zeros((n_batch, n_hidden)) + flashinfer_times = np.zeros((n_batch, n_hidden)) + speedups = np.zeros((n_batch, n_hidden)) + + print("Running benchmarks...") + print("=" * 100) + print( + f"{'Batch Size':<12} {'Hidden Size':<12} {'Torch (ms)':<15} " + f"{'FlashInfer (ms)':<18} {'Speedup':<10} {'Bandwidth (GB/s)':<18}" + ) + print("=" * 100) + + for i, batch_size in enumerate(batch_sizes): + for j, hidden_size in enumerate(hidden_sizes): + # Generate random logits + torch.manual_seed(42) + logits = torch.randn( + batch_size, hidden_size, device="cuda", dtype=torch.float32 + ) + + # Benchmark torch softmax + torch_time_ms = benchmark_torch_softmax(logits) + torch_times[i, j] = torch_time_ms + + # Benchmark flashinfer softmax + flashinfer_time_ms = benchmark_flashinfer_softmax(logits) + flashinfer_times[i, j] = flashinfer_time_ms + + # Calculate speedup + speedup = torch_time_ms / flashinfer_time_ms + speedups[i, j] = speedup + + # Calculate effective bandwidth (read + write) + io_bytes = logits.numel() * logits.element_size() * 2 + bandwidth_gb_s = io_bytes * 1e-6 / flashinfer_time_ms + + print( + f"{batch_size:<12} {hidden_size:<12} {torch_time_ms:<15.4f} " + f"{flashinfer_time_ms:<18.4f} {speedup:<10.2f}x {bandwidth_gb_s:<18.2f}" + ) + + print("=" * 100) + return torch_times, flashinfer_times, speedups + + +def plot_heatmap( + speedups: np.ndarray, + batch_sizes: List[int], + hidden_sizes: List[int], + save_path: str = "softmax_speedup_heatmap.png", +): + """Create and save a heatmap of speedup values.""" + # Create figure + fig, ax = plt.subplots(figsize=(12, 8)) + + # Create heatmap + sns.heatmap( + speedups, + annot=True, + fmt=".2f", + cmap="RdYlGn", + center=1.0, + cbar_kws={"label": "Speedup (x)"}, + xticklabels=[f"{h // 1000}K" for h in hidden_sizes], + yticklabels=batch_sizes, + ax=ax, + vmin=0.5, # Adjust color scale + vmax=max(3.0, speedups.max()), # Dynamic upper bound + ) + + ax.set_xlabel("Hidden Size", fontsize=12, fontweight="bold") + ax.set_ylabel("Batch Size", fontsize=12, fontweight="bold") + ax.set_title( + "FlashInfer Softmax Speedup vs PyTorch (Higher is Better)", + fontsize=14, + fontweight="bold", + pad=20, + ) + + plt.tight_layout() + plt.savefig(save_path, dpi=300, bbox_inches="tight") 
+ print(f"\nHeatmap saved to: {save_path}") + + # Also create a performance comparison plot + _, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6)) + + # Plot 2: Speedup trends across batch sizes + for j, hidden_size in enumerate(hidden_sizes): + ax2.plot( + batch_sizes, + speedups[:, j], + marker="o", + label=f"Hidden={hidden_size // 1000}K", + linewidth=2, + ) + + ax2.set_xlabel("Batch Size", fontsize=12, fontweight="bold") + ax2.set_ylabel("Speedup (x)", fontsize=12, fontweight="bold") + ax2.set_title("Speedup vs Batch Size", fontsize=13, fontweight="bold") + ax2.set_xscale("log", base=2) + ax2.grid(True, alpha=0.3) + ax2.legend(fontsize=9) + ax2.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="No speedup") + + # Plot 1: Speedup trends across hidden sizes + for i, batch_size in enumerate(batch_sizes[::2]): # Sample every other batch size + idx = i * 2 + ax1.plot( + [h // 1000 for h in hidden_sizes], + speedups[idx, :], + marker="s", + label=f"Batch={batch_size}", + linewidth=2, + ) + + ax1.set_xlabel("Hidden Size (K)", fontsize=12, fontweight="bold") + ax1.set_ylabel("Speedup (x)", fontsize=12, fontweight="bold") + ax1.set_title("Speedup vs Hidden Size", fontsize=13, fontweight="bold") + ax1.grid(True, alpha=0.3) + ax1.legend(fontsize=9) + ax1.axhline(y=1.0, color="red", linestyle="--", alpha=0.5) + + plt.tight_layout() + comparison_path = save_path.replace(".png", "_trends.png") + plt.savefig(comparison_path, dpi=300, bbox_inches="tight") + print(f"Trend plots saved to: {comparison_path}") + + +def main(): + """Main benchmark execution.""" + # Configuration + batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256, 512, 1024] + hidden_sizes = [32000, 64000, 128000, 256000] + + print("=" * 100) + print("FlashInfer vs PyTorch Softmax Benchmark") + print("=" * 100) + print(f"Batch sizes: {batch_sizes}") + print(f"Hidden sizes: {hidden_sizes}") + print(f"Device: {torch.cuda.get_device_name()}") + print("=" * 100) + print() + + # Run benchmarks + _, _, speedups = run_benchmark(batch_sizes, hidden_sizes) + + # Print summary statistics + print("\nSummary Statistics:") + print("=" * 100) + print(f"Average speedup: {np.mean(speedups):.2f}x") + print(f"Median speedup: {np.median(speedups):.2f}x") + print(f"Min speedup: {np.min(speedups):.2f}x") + print(f"Max speedup: {np.max(speedups):.2f}x") + print("=" * 100) + + # Generate heatmap + plot_heatmap(speedups, batch_sizes, hidden_sizes) + + print("\nBenchmark complete!") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py index 952b479a1d..203faaff82 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py +++ b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -8,13 +8,164 @@ fp4_quantize, mxfp8_quantize, ) -from flashinfer.fused_moe import trtllm_fp4_block_scale_moe +from flashinfer.fused_moe import ( + trtllm_fp4_block_scale_moe, + trtllm_fp8_per_tensor_scale_moe, + trtllm_fp8_block_scale_moe, + WeightLayout, +) from flashinfer.autotuner import autotune from flashinfer.testing.utils import bench_gpu_time -from flashinfer.utils import device_support_pdl, calculate_tile_tokens_dim +from flashinfer.utils import device_support_pdl + +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max +FLOAT4_E2M1_MAX = 6.0 + + +def fp8_quantize(x): + max = x.abs().max().float() + scale = FLOAT8_E4M3_MAX / max + x = (x * scale).to(torch.float8_e4m3fn) + return x, 1.0 / scale + + +def bench_trtllm_gen_fused_moe_autotuner_fp8( + tune_max_num_tokens: 
Optional[int], + quant_mode: Literal["Fp8-Per-Tensor", "Fp8-Block"], + num_tokens: int, + num_experts: int, + hidden_size: int, + intermediate_size: int, + top_k: int, + warmups: int, + iterations: int, +): + device = torch.device("cuda:0") + enable_pdl = device_support_pdl(device) + routing_logits = torch.rand(num_tokens, num_experts, device=device).to( + torch.float32 + ) + hidden_states = torch.randn(num_tokens, hidden_size, device=device).to( + torch.bfloat16 + ) + routing_bias = torch.randn(num_experts, device="cuda", dtype=torch.bfloat16) + w13 = torch.randn( + num_experts, intermediate_size * 2, hidden_size, device=device + ).to(torch.bfloat16) + w2 = torch.randn(num_experts, hidden_size, intermediate_size, device=device).to( + torch.bfloat16 + ) + + is_block_scale = quant_mode == "Fp8-Block" + if not is_block_scale: + hidden_states, hidden_states_scale = fp8_quantize(hidden_states) + w13, w13_scale = fp8_quantize(w13) + w2, w2_scale = fp8_quantize(w2) + else: + # block scale quantization is too slow, so we use per-tensor quantization for now + hidden_states, hidden_states_scale = fp8_quantize(hidden_states) + w13, w13_scale = fp8_quantize(w13) + w2, w2_scale = fp8_quantize(w2) + hidden_states_scale = torch.full( + (hidden_size // 128, num_tokens), hidden_states_scale.item(), device=device + ) + w13_scale = torch.full( + (num_experts, intermediate_size * 2 // 128, hidden_size // 128), + w13_scale.item(), + device=device, + ) + w2_scale = torch.full( + (num_experts, hidden_size // 128, intermediate_size // 128), + w2_scale.item(), + device=device, + ) + + output1_scale_scalar = ( + torch.tensor([hidden_states_scale * w13_scale] * num_experts, device=device) + if not is_block_scale + else None + ) + output1_scales_gate_scalar = ( + torch.ones(num_experts, device=device, dtype=torch.float32) + if not is_block_scale + else None + ) + output2_scale_scalar = ( + torch.tensor([hidden_states_scale * w2_scale] * num_experts, device=device) + if not is_block_scale + else None + ) + + if is_block_scale: + fn = lambda: trtllm_fp8_block_scale_moe( + routing_logits, + routing_bias, + hidden_states, + hidden_states_scale, + w13, + w13_scale, + w2, + w2_scale, + num_experts, + top_k, + 8, # n_group + 4, # topk_group + intermediate_size, + 0, # local_expert_offset + num_experts, + 2.5, # routed_scaling_factor + RoutingMethodType.DeepSeekV3.value, + True, # use_shuffled_weight + WeightLayout.BlockMajorK.value, # weight_layout + enable_pdl=enable_pdl, + tune_max_num_tokens=num_tokens + if tune_max_num_tokens is None + else tune_max_num_tokens, + ) + else: + fn = lambda: trtllm_fp8_per_tensor_scale_moe( + routing_logits, + None, # routing_bias + hidden_states, + w13, + output1_scale_scalar, + output1_scales_gate_scalar, + w2, + output2_scale_scalar, + num_experts, + top_k, + None, # n_group + None, # topk_group + intermediate_size, + 0, # local_expert_offset + num_experts, + 1.0, # routed_scaling_factor + False, # use_routing_scales_on_input + RoutingMethodType.TopK.value, + enable_pdl, + num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, + ) + + def bench(do_autotune): + with autotune(do_autotune): + fn() + ms_list = bench_gpu_time( + fn, + dry_run_iters=warmups, + repeat_iters=iterations, + ) + median_ms = np.median(ms_list) + return median_ms + ms = bench(do_autotune=False) + ms_tuned = bench(do_autotune=True) + print( + f"num tokens: {num_tokens}, num experts: {num_experts}, hidden size: {hidden_size}, intermediate size: {intermediate_size}, top k: {top_k}" + ) + print(f"No 
autotune: {ms:.3f} ms; with autotune: {ms_tuned:.3f} ms") -def bench_trtllm_gen_fused_moe_autotuner( + +def bench_trtllm_gen_fused_moe_autotuner_fp4( tune_max_num_tokens: Optional[int], quant_mode: Literal["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"], num_tokens: int, @@ -39,6 +190,7 @@ def bench_trtllm_gen_fused_moe_autotuner( torch.tensor([448.0 * 6.0], device=device), sf_vec_size=16, sf_use_ue8m0=False, + is_sf_swizzled_layout=False, ) hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape( num_tokens, -1 @@ -99,9 +251,6 @@ def bench_trtllm_gen_fused_moe_autotuner( bias13 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10 bias2 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10 - tile_tokens_dim = calculate_tile_tokens_dim( - num_tokens, num_experts, top_k, 64 if quant_mode == "MxFP4xBf16" else 128 - ) output1_scale_scalar = torch.tensor( [hidden_states_global_scale * w13_global_scale] * num_experts, device=device ) @@ -136,7 +285,6 @@ def bench_trtllm_gen_fused_moe_autotuner( 0, # local_expert_offset num_experts, None, # routed_scaling_factor - tile_tokens_dim, RoutingMethodType.Renormalize.value, True, enable_pdl, @@ -146,12 +294,11 @@ def bench_trtllm_gen_fused_moe_autotuner( ) def bench(do_autotune): - # warmup with autotune(do_autotune): - for _ in range(warmups): - fn() + fn() ms_list = bench_gpu_time( fn, + dry_run_iters=warmups, repeat_iters=iterations, ) median_ms = np.median(ms_list) @@ -171,7 +318,13 @@ def bench(do_autotune): "--quant-mode", type=str, default="MxFP4xMxFP8", - choices=["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"], + choices=[ + "NvFP4xNvFP4", + "MxFP4xMxFP8", + "MxFP4xBf16", + "Fp8-Per-Tensor", + "Fp8-Block", + ], help="Quantization mode", ) parser.add_argument("--num-tokens", type=int, default=512, help="Number of tokens") @@ -196,14 +349,27 @@ def bench(do_autotune): "--iterations", type=int, default=100, help="Number of benchmark iterations" ) args = parser.parse_args() - bench_trtllm_gen_fused_moe_autotuner( - args.tune_max_num_tokens, - args.quant_mode, - args.num_tokens, - args.num_experts, - args.hidden_size, - args.intermediate_size, - args.top_k, - args.warmups, - args.iterations, - ) + if args.quant_mode in ["Fp8-Per-Tensor", "Fp8-Block"]: + bench_trtllm_gen_fused_moe_autotuner_fp8( + args.tune_max_num_tokens, + args.quant_mode, + args.num_tokens, + args.num_experts, + args.hidden_size, + args.intermediate_size, + args.top_k, + args.warmups, + args.iterations, + ) + else: + bench_trtllm_gen_fused_moe_autotuner_fp4( + args.tune_max_num_tokens, + args.quant_mode, + args.num_tokens, + args.num_experts, + args.hidden_size, + args.intermediate_size, + args.top_k, + args.warmups, + args.iterations, + ) diff --git a/benchmarks/flashinfer_benchmark.py b/benchmarks/flashinfer_benchmark.py index bd02172eb2..330d734221 100644 --- a/benchmarks/flashinfer_benchmark.py +++ b/benchmarks/flashinfer_benchmark.py @@ -79,7 +79,13 @@ def parse_args(line=sys.argv[1:]): "--use_cupti", action="store_true", default=False, - help="Use CUPTI for timing GPU kernels when available.", + help="[DEPRECATED] Use CUPTI for timing GPU kernels. 
This is now the default behavior.", + ) + parser.add_argument( + "--use_cuda_events", + action="store_true", + default=False, + help="Use CUDA events for timing GPU kernels instead of CUPTI.", ) parser.add_argument( "--refcheck", @@ -155,6 +161,16 @@ def parse_args(line=sys.argv[1:]): if args.generate_repro_command: args.repro_command = "python3 flashinfer_benchmark.py " + " ".join(line) + + # Deprecation warning for use_cupti + if args.use_cupti: + print( + "[WARNING] --use_cupti is deprecated and will be removed in a future release. CUPTI is now enabled by default." + ) + # use_cupti is deprecated and will be removed in a future release. CUPTI is now enabled by default. + # If --use_cuda_events is passed, disable use_cupti + args.use_cupti = not args.use_cuda_events + return args diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py index bfebc37d4d..320cfbe020 100644 --- a/benchmarks/routines/attention.py +++ b/benchmarks/routines/attention.py @@ -19,6 +19,30 @@ ) +def normalize_backends(backends): + """ + Normalize backend names planned for deprecation and print warnings. + Currently: + - Replaces deprecated 'trtllm-gen-native' with 'trtllm-native'. + + Args: + backends: List of backend names + + Returns: + List of normalized backend names + """ + normalized = [] + for backend in backends: + if backend == "trtllm-gen-native": + print( + "[WARNING] Backend name 'trtllm-gen-native' has been renamed to 'trtllm-native' and will be removed in a future release. " + ) + normalized.append("trtllm-native") + else: + normalized.append(backend) + return normalized + + def run_attention_test(args): """ Run an attention test. @@ -66,7 +90,8 @@ def parse_attention_args(line, parser): "cudnn", "cutlass", "trtllm-gen", - "trtllm-gen-native", + "trtllm-native", + "trtllm-gen-native", # Deprecated, will be removed in future ], help="Kernel backends to test. Default: fa2", ) @@ -151,6 +176,10 @@ def parse_attention_args(line, parser): ) args = parser.parse_args(line) + + # Normalize backend names (handle deprecated names) + args.backends = normalize_backends(args.backends) + if args.verbose >= 1: print(f"[INFO] {args = }") return args @@ -185,7 +214,7 @@ def sample_actual_seq_lens(max_seqlen, batch_size, device, random_actual_seq_len def testBatchDecodeWithPagedKVCacheWrapper(args): """ Test BatchDecodeWithPagedKVCacheWrapper API and equivalent cuDNN API. - Supports fa2, fa2_tc, cudnn, trtllm-gen, trtllm-gen-native backends. + Supports fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native backends. This test: 1. 
Creates paged KV cache and query tensors @@ -367,7 +396,7 @@ def testBatchDecodeWithPagedKVCacheWrapper(args): # Now initialize the page tables block_tables = torch.tensor( [ - [k + i * num_pages_per_seq for k in range(num_pages_per_seq)] + [k + i * num_pages_per_seq for k in torch.randperm(num_pages_per_seq)] for i in range(batch_size) ], dtype=torch.int, @@ -392,11 +421,7 @@ def testBatchDecodeWithPagedKVCacheWrapper(args): for i in range(len(kv_indptr) - 1): start_idx = kv_indptr[i] end_idx = kv_indptr[i + 1] - kv_indices[start_idx:end_idx] = torch.arange( - i * num_pages_per_seq, - i * num_pages_per_seq + (end_idx - start_idx), - device=device, - ) + kv_indices[start_idx:end_idx] = block_tables[i, : end_idx - start_idx] kv_last_page_len = ( torch.where( @@ -490,7 +515,7 @@ def run_backend_wrapper(backend): batch_offsets_q=ragged_q, batch_offsets_o=ragged_q, ) - elif backend == "trtllm-gen-native": + elif backend == "trtllm-native": return flashinfer.decode.trtllm_batch_decode_with_kv_cache( query=q.contiguous(), kv_cache=kv_cache, @@ -508,6 +533,8 @@ def run_backend_wrapper(backend): has_reference_output = False # Iterate over each backend: for cur_backend in backends: + # Clear workspace buffer to prevent unexpected interactions between backends. + workspace_buffer.zero_() if run_refcheck: outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone() if cur_backend == "fa2": @@ -612,7 +639,7 @@ def run_backend_wrapper(backend): def testBatchPrefillWithPagedKVCacheWrapper(args): """ Test BatchPrefillWithPagedKVCacheWrapper API and equivalent cuDNN API. - Supports fa2, fa3, trtllm-gen, trtllm-gen-native, and cudnn backends. + Supports fa2, fa3, trtllm-gen, trtllm-native, and cudnn backends. This test: 1. Creates paged KV cache and query tensors for prefill @@ -695,13 +722,13 @@ def testBatchPrefillWithPagedKVCacheWrapper(args): remove_trtllm = True if remove_trtllm: backends.remove("trtllm-gen") - if "trtllm-gen-native" in backends: + if "trtllm-native" in backends: remove_trtllm_native = False if not causal: - print("[INFO] trtllm-gen-native backend currently requires causal = True") + print("[INFO] trtllm-native backend currently requires causal = True") remove_trtllm_native = True if remove_trtllm_native: - backends.remove("trtllm-gen-native") + backends.remove("trtllm-native") if "cutlass" in backends: print("[INFO] CUTLASS backend does not support prefill. 
Skipping.") @@ -806,7 +833,7 @@ def testBatchPrefillWithPagedKVCacheWrapper(args): # Now initialize the page tables block_tables = torch.tensor( [ - [k + i * num_pages_per_seq for k in range(num_pages_per_seq)] + [k + i * num_pages_per_seq for k in torch.randperm(num_pages_per_seq)] for i in range(batch_size) ], dtype=torch.int, @@ -856,11 +883,7 @@ def testBatchPrefillWithPagedKVCacheWrapper(args): for i in range(len(kv_indptr) - 1): start_idx = kv_indptr[i] end_idx = kv_indptr[i + 1] - kv_indices[start_idx:end_idx] = torch.arange( - i * num_pages_per_seq, - i * num_pages_per_seq + (end_idx - start_idx), - device=device, - ) + kv_indices[start_idx:end_idx] = block_tables[i, : end_idx - start_idx] kv_last_page_len = ( torch.where( actual_seq_lens_kv_device.flatten() % page_size == 0, @@ -953,7 +976,7 @@ def run_backend_wrapper(backend): batch_offsets_q=q_indptr, batch_offsets_o=q_indptr, )[0] - elif backend == "trtllm-gen-native": + elif backend == "trtllm-native": return flashinfer.prefill.trtllm_batch_context_with_kv_cache( query=q, kv_cache=kv_cache, @@ -975,6 +998,8 @@ def run_backend_wrapper(backend): has_reference_output = False # Iterate over each backend: for cur_backend in backends: + # Clear workspace buffer to prevent unexpected interactions between backends. + workspace_buffer.zero_() if run_refcheck: outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone() if cur_backend == "fa2": @@ -1174,21 +1199,21 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args): remove_trtllm = True if remove_trtllm: backends.remove("trtllm-gen") - if "trtllm-gen-native" in backends: + if "trtllm-native" in backends: remove_trtllm_native = False if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [ torch.float8_e4m3fn, torch.float8_e5m2, ]: - print("[INFO] trtllm-gen-native backend does not support FP8. Skipping.") + print("[INFO] trtllm-native backend does not support FP8. Skipping.") remove_trtllm_native = True if not (head_dim_qk == 192 and head_dim_vo == 128): print( - "[INFO] trtllm-gen-native backend requires head_dim_qk == 192 and head_dim_vo == 128" + "[INFO] trtllm-native backend requires head_dim_qk == 192 and head_dim_vo == 128" ) remove_trtllm_native = True if remove_trtllm_native: - backends.remove("trtllm-gen-native") + backends.remove("trtllm-native") if len(backends) == 0: print("[ERROR] No backends to test. Exiting.") @@ -1400,7 +1425,7 @@ def run_backend_wrapper(backend): batch_offsets_stats=batch_offsets_stats, is_cuda_graph_compatible=True, )[0] - elif backend == "trtllm-gen-native": + elif backend == "trtllm-native": return flashinfer.prefill.trtllm_ragged_attention_deepseek( query=q, key=k, @@ -1427,6 +1452,8 @@ def run_backend_wrapper(backend): has_reference_output = False # Iterate over each backend: for cur_backend in backends: + # Clear workspace buffer to prevent unexpected interactions between backends. + workspace_buffer.zero_() if run_refcheck: outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone() if cur_backend == "fa2": @@ -1532,7 +1559,7 @@ def run_backend_wrapper(backend): def testBatchMLAPagedAttentionWrapper(args): """ Test BatchMLAPagedAttentionWrapper and equivalent APIs. - Supports fa2, fa3, cutlass, and trtllm-gen-native. + Supports fa2, fa3, cutlass, and trtllm-native. This test: 1. 
Creates paged query and key-value cache tensors @@ -1628,15 +1655,15 @@ def testBatchMLAPagedAttentionWrapper(args): remove_cutlass = True if remove_cutlass: backends.remove("cutlass") - if "trtllm-gen-native" in backends: + if "trtllm-native" in backends: remove_trtllm_native = False if page_size not in [32, 64]: print( - "[INFO] trtllm-gen-native backend only supports page size 32 or 64. Skipping." + "[INFO] trtllm-native backend only supports page size 32 or 64. Skipping." ) remove_trtllm_native = True if remove_trtllm_native: - backends.remove("trtllm-gen-native") + backends.remove("trtllm-native") if len(backends) == 0: print("[ERROR] No backends to test. Exiting.") return res @@ -1676,7 +1703,7 @@ def testBatchMLAPagedAttentionWrapper(args): # Now initialize the page tables block_tables = torch.tensor( [ - [k + i * num_pages_per_seq for k in range(num_pages_per_seq)] + [k + i * num_pages_per_seq for k in torch.randperm(num_pages_per_seq)] for i in range(batch_size) ], dtype=torch.int, @@ -1723,11 +1750,7 @@ def testBatchMLAPagedAttentionWrapper(args): for i in range(len(kv_indptr) - 1): start_idx = kv_indptr[i] end_idx = kv_indptr[i + 1] - kv_indices[start_idx:end_idx] = torch.arange( - i * num_pages_per_seq, - i * num_pages_per_seq + (end_idx - start_idx), - device=device, - ) + kv_indices[start_idx:end_idx] = block_tables[i, : end_idx - start_idx] sm_scale = 1.0 / ((128 + 64) ** 0.5) # For DeepSeek-R1 workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) @@ -1801,8 +1824,8 @@ def run_backend_wrapper(backend): page_table=block_tables, return_lse=False, ) - if backend == "trtllm-gen-native": - return flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( + elif backend == "trtllm-native": + return flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla( query=q.unsqueeze(1), kv_cache=kv_cache.unsqueeze(1), workspace_buffer=workspace_buffer, @@ -1822,6 +1845,8 @@ def run_backend_wrapper(backend): has_reference_output = False # Iterate over each backend: for cur_backend in backends: + # Clear workspace buffer to prevent unexpected interactions between backends. 
+ workspace_buffer.zero_() if run_refcheck: outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone() if cur_backend == "fa2": diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index fa1a527d17..d5f363839a 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -53,7 +53,6 @@ "routed_scaling_factor", "local_expert_offset", "local_num_experts", - "tile_tokens_dim", "routing_method", "use_shuffled_weight", "weight_layout", @@ -162,43 +161,47 @@ def dtype_str_to_torch_dtype(dtype_str): routine_cc_to_supported_backends = { # ATTENTION "BatchDecodeWithPagedKVCacheWrapper": { + # NOTE: trtllm-native calls trtllm_batch_decode_with_kv_cache "7.5": ["fa2"], "8.0": ["fa2", "fa2_tc", "cudnn"], "8.6": ["fa2", "fa2_tc", "cudnn"], "8.9": ["fa2", "fa2_tc", "cudnn"], - "9.0": ["fa2", "fa2_tc", "cudnn"], - "10.0": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-gen-native"], - "10.3": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-gen-native"], - "12.0": ["fa2", "fa2_tc", "cudnn"], + "9.0": ["fa2", "fa2_tc", "cudnn", "trtllm-native"], + "10.0": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-native"], + "10.3": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-native"], + "12.0": ["fa2", "fa2_tc", "cudnn", "trtllm-native"], }, "BatchPrefillWithPagedKVCacheWrapper": { + # NOTE: trtllm-native calls trtllm_batch_context_with_kv_cache "7.5": [], "8.0": ["fa2", "cudnn"], "8.6": ["fa2", "cudnn"], "8.9": ["fa2", "cudnn"], "9.0": ["fa2", "fa3", "cudnn"], - "10.0": ["fa2", "cudnn", "trtllm-gen", "trtllm-gen-native"], - "10.3": ["fa2", "cudnn", "trtllm-gen", "trtllm-gen-native"], + "10.0": ["fa2", "cudnn", "trtllm-gen", "trtllm-native"], + "10.3": ["fa2", "cudnn", "trtllm-gen", "trtllm-native"], "12.0": ["fa2", "cudnn"], }, "BatchPrefillWithRaggedKVCacheWrapper": { + # NOTE: trtllm-native calls trtllm_ragged_attention_deepseek "7.5": [], "8.0": ["fa2", "cudnn"], "8.6": ["fa2", "cudnn"], "8.9": ["fa2", "cudnn"], "9.0": ["fa2", "fa3", "cudnn"], - "10.0": ["fa2", "cudnn", "cutlass", "trtllm-gen-native"], - "10.3": ["fa2", "cudnn", "cutlass", "trtllm-gen-native"], + "10.0": ["fa2", "cudnn", "cutlass", "trtllm-native"], + "10.3": ["fa2", "cudnn", "cutlass", "trtllm-native"], "12.0": ["fa2", "cudnn"], }, "BatchMLAPagedAttentionWrapper": { + # NOTE: trtllm-native calls trtllm_batch_decode_with_kv_cache_mla "7.5": [], "8.0": ["fa2"], "8.6": ["fa2"], "8.9": ["fa2"], "9.0": ["fa2", "fa3"], - "10.0": ["fa2", "cutlass", "trtllm-gen-native"], - "10.3": ["fa2", "cutlass", "trtllm-gen-native"], + "10.0": ["fa2", "cutlass", "trtllm-native"], + "10.3": ["fa2", "cutlass", "trtllm-native"], "12.0": ["fa2"], }, # GEMM @@ -232,16 +235,7 @@ def dtype_str_to_torch_dtype(dtype_str): "10.3": ["cudnn", "cublas", "cutlass"], "12.0": ["cudnn", "cublas"], }, - "mm_fp4": { - "7.5": [], - "8.0": [], - "8.6": [], - "8.9": [], - "9.0": [], - "10.0": ["cudnn", "trtllm", "cutlass"], - "10.3": ["cudnn", "trtllm", "cutlass"], - "12.0": ["cudnn", "cutlass"], - }, + # Note: mm_fp4 uses support checkers to filter backends, so it is not listed here # MOE "trtllm_fp4_block_scale_moe": { "7.5": [], diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py index 17336189d0..9f95f17fb4 100644 --- a/benchmarks/routines/gemm.py +++ b/benchmarks/routines/gemm.py @@ -131,7 +131,7 @@ def parse_gemm_args(line, parser): required=False, nargs="+", default=["cudnn"], - choices=["cudnn", "cublas", "trtllm", 
"cutlass"], + choices=["cudnn", "cublas", "trtllm", "cutlass", "auto"], help="Kernel backends to test. Default: cudnn", ) parser.add_argument( @@ -790,61 +790,14 @@ def testMmFp4(args): run_refcheck = args.refcheck use_128x4_sf_layout = args.use_128x4_sf_layout use_nvfp4 = args.use_nvfp4 - autotune_supported_backends = ["cutlass", "trtllm"] + autotune_supported_backends = ["cudnn", "cutlass", "trtllm", "auto"] res = [] - backends = filter_backends_by_compute_capability(backends, args.routine, device) - res_dtype = dtype_str_to_torch_dtype(args.out_dtype) if res_dtype not in [torch.bfloat16, torch.float16]: raise ValueError( f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16." ) - ## Done parsing input arguments - - if "trtllm" in backends: - remove_trtllm = False - if res_dtype == torch.float16: - print("[INFO] trtllm backend does not support float16 output") - remove_trtllm = True - if remove_trtllm: - backends.remove("trtllm") - if not use_nvfp4: - print( - "[INFO] trtllm backend does not support mxfp4 quantization (use_nvfp4=False)" - ) - backends.remove("trtllm") - if "cutlass" in backends: - remove_cutlass = False - if not use_128x4_sf_layout: - print("[INFO] cutlass backend does not support use_128x4_sf_layout=False") - remove_cutlass = True - if not use_nvfp4: - print( - "[INFO] cutlass backend does not support mxfp4 quantization (use_nvfp4=False)" - ) - backends.remove("cutlass") - if remove_cutlass: - backends.remove("cutlass") - if "cudnn" in backends: - remove_cudnn = False - if not use_128x4_sf_layout: - print("[INFO] cudnn backend does not support use_128x4_sf_layout=False") - remove_cudnn = True - if remove_cudnn: - backends.remove("cudnn") - if getattr(args, "autotune", False): - backends_to_remove = [] - for cur_backend in backends: - if cur_backend not in autotune_supported_backends: - print(f"[INFO] {cur_backend} backend does not support autotune") - backends_to_remove.append(cur_backend) - for cur_backend in backends_to_remove: - backends.remove(cur_backend) - - if len(backends) == 0: - print("[ERROR] No backends to test. Exiting.") - return input = torch.randn([m, k], device=device, dtype=torch.bfloat16) mat2 = torch.randn([n, k], device=device, dtype=torch.bfloat16) @@ -886,11 +839,22 @@ def testMmFp4(args): print(f"[VVERBOSE] {mat2_fp4.dtype = }") alpha = 1.0 / (global_sf_input * global_sf_mat2) if use_nvfp4 else None - # res = torch.empty([m, n], device="cuda", dtype=res_dtype) + # Completed preparing inputs. 
Now programmatically filter backends + block_size = 16 if use_nvfp4 else 32 + backends_to_remove = [] - def run_backend(backend): - if backend in ["cudnn", "trtllm", "cutlass"]: - return flashinfer.gemm.mm_fp4( + for backend in backends: + # Skip autotune check for now (handled separately below) + if ( + getattr(args, "autotune", False) + and backend not in autotune_supported_backends + ): + print(f"[INFO] {backend} backend does not support autotune") + backends_to_remove.append(backend) + continue + + try: + flashinfer.gemm.mm_fp4( a=input_fp4, b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T, a_descale=input_inv_s, @@ -904,6 +868,34 @@ def run_backend(backend): backend=backend, use_nvfp4=use_nvfp4, ) + except Exception as e: + print( + f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}" + ) + backends_to_remove.append(backend) + + # Remove unsupported backends + for backend in backends_to_remove: + backends.remove(backend) + + if len(backends) == 0: + print("[ERROR] No backends passed validation. Exiting.") + return + + def run_backend(backend): + if backend in ["cudnn", "trtllm", "cutlass", "auto"]: + return flashinfer.gemm.mm_fp4( + a=input_fp4, + b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T, + a_descale=input_inv_s, + b_descale=mat2_inv_s.T if backend != "trtllm" else mat2_inv_s_trtllm.T, + alpha=alpha, + out_dtype=res_dtype, + block_size=block_size, + use_8x4_sf_layout=not use_128x4_sf_layout, + backend=backend, + use_nvfp4=use_nvfp4, + ) else: raise ValueError(f"Unsupported backend: {backend}") @@ -917,12 +909,11 @@ def run_backend(backend): args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10 ) for cur_backend in backends: - if cur_backend in autotune_supported_backends: - if args.verbose >= 1: - print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters") - with autotune(True): - for _ in range(warmup_iters): - run_backend(cur_backend) + if args.verbose >= 1: + print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters") + with autotune(True): + for _ in range(warmup_iters): + run_backend(cur_backend) # Storage for timing results and outputs backend_times = {backend: [] for backend in backends} diff --git a/benchmarks/routines/moe.py b/benchmarks/routines/moe.py index 6af3425c73..8f26bdb8f7 100644 --- a/benchmarks/routines/moe.py +++ b/benchmarks/routines/moe.py @@ -116,13 +116,6 @@ def parse_moe_args(line, parser): default=None, help="Number of experts handled by this device. 
Defaults to num_experts.", ) - parser.add_argument( - "--tile_tokens_dim", - type=int, - required=False, - default=8, - help="Tile dimension for tokens.", - ) parser.add_argument( "--routing_method", type=str, @@ -560,7 +553,6 @@ def testTrtllmFp4BlockScaleMoe(args): ) local_expert_offset = args.local_expert_offset local_num_experts = args.local_num_experts or num_experts - tile_tokens_dim = args.tile_tokens_dim routing_method_type = args.routing_method_type use_shuffled_weight = args.use_shuffled_weight weight_layout = args.weight_layout @@ -705,7 +697,6 @@ def run_fp4_moe(): local_expert_offset=local_expert_offset, local_num_experts=local_num_experts, routed_scaling_factor=routed_scaling_factor, - tile_tokens_dim=tile_tokens_dim, routing_method_type=routing_method_type, gated_act_type=gated_act_type, do_finalize=True, @@ -780,7 +771,6 @@ def run_fp4_moe(): cur_res["routed_scaling_factor"] = routed_scaling_factor cur_res["local_expert_offset"] = local_expert_offset cur_res["local_num_experts"] = local_num_experts - cur_res["tile_tokens_dim"] = tile_tokens_dim cur_res["routing_method"] = args.routing_method cur_res["use_shuffled_weight"] = use_shuffled_weight cur_res["weight_layout"] = weight_layout @@ -1185,7 +1175,6 @@ def testTrtllmFp8BlockScaleMoe(args): ) local_expert_offset = args.local_expert_offset local_num_experts = args.local_num_experts or num_experts - tile_tokens_dim = args.tile_tokens_dim routing_method_type = args.routing_method_type use_shuffled_weight = args.use_shuffled_weight weight_layout = args.weight_layout @@ -1277,27 +1266,6 @@ def testTrtllmFp8BlockScaleMoe(args): print(f"[VVERBOSE] gemm1_weights_fp8.shape = {gemm1_weights_fp8.shape}") print(f"[VVERBOSE] gemm2_weights_fp8.shape = {gemm2_weights_fp8.shape}") - # Match test heuristic for tile_tokens_dim when using BlockMajorK - if use_shuffled_weight and weight_layout == WeightLayout.BlockMajorK: - - def _next_pow2(x: int) -> int: - x = max(1, x) - x -= 1 - x |= x >> 1 - x |= x >> 2 - x |= x >> 4 - x |= x >> 8 - x |= x >> 16 - return x + 1 - - tokens_per_expert = max(1, (num_tokens * top_k) // max(local_num_experts, 1)) - suggested_tile = min(max(_next_pow2(tokens_per_expert), 8), 64) - if suggested_tile != tile_tokens_dim and args.verbose >= 1: - print( - f"[INFO] Overriding tile_tokens_dim {tile_tokens_dim} -> {suggested_tile} for BlockMajorK" - ) - tile_tokens_dim = suggested_tile - def run_fp8_block_moe(): # Quantize hidden states to FP8 for block scale MOE hidden_states_fp8 = hidden_states.to(torch.float8_e4m3fn) @@ -1320,7 +1288,6 @@ def run_fp8_block_moe(): local_expert_offset=local_expert_offset, local_num_experts=local_num_experts, routed_scaling_factor=routed_scaling_factor, - tile_tokens_dim=tile_tokens_dim, routing_method_type=routing_method_type, use_shuffled_weight=use_shuffled_weight, weight_layout=weight_layout, @@ -1381,7 +1348,6 @@ def run_fp8_block_moe(): cur_res["routed_scaling_factor"] = routed_scaling_factor cur_res["local_expert_offset"] = local_expert_offset cur_res["local_num_experts"] = local_num_experts - cur_res["tile_tokens_dim"] = tile_tokens_dim cur_res["routing_method"] = args.routing_method cur_res["use_shuffled_weight"] = use_shuffled_weight cur_res["weight_layout"] = weight_layout @@ -1448,7 +1414,6 @@ def testTrtllmFp8PerTensorScaleMoe(args): ) local_expert_offset = args.local_expert_offset local_num_experts = args.local_num_experts or num_experts - tile_tokens_dim = args.tile_tokens_dim routing_method_type = args.routing_method_type use_routing_scales_on_input = 
args.use_routing_scales_on_input is_cuda_graph_compatible = not args.no_cuda_graph @@ -1527,7 +1492,6 @@ def run_fp8_per_tensor_moe(): local_num_experts=local_num_experts, routed_scaling_factor=routed_scaling_factor, use_routing_scales_on_input=use_routing_scales_on_input, - tile_tokens_dim=tile_tokens_dim, routing_method_type=routing_method_type, ) @@ -1585,7 +1549,6 @@ def run_fp8_per_tensor_moe(): cur_res["routed_scaling_factor"] = routed_scaling_factor cur_res["local_expert_offset"] = local_expert_offset cur_res["local_num_experts"] = local_num_experts - cur_res["tile_tokens_dim"] = tile_tokens_dim cur_res["routing_method"] = args.routing_method cur_res["use_routing_bias"] = args.use_routing_bias cur_res["use_routing_scales_on_input"] = use_routing_scales_on_input diff --git a/benchmarks/samples/sample_testlist_output.csv b/benchmarks/samples/sample_testlist_output.csv index d856d37ab0..b07c523ecb 100644 --- a/benchmarks/samples/sample_testlist_output.csv +++ b/benchmarks/samples/sample_testlist_output.csv @@ -1,4 +1,4 @@ -routine,median_time,std_time,tflops,tb_per_sec,backend,page_size,batch_size,s_qo,s_kv,num_qo_heads,num_kv_heads,head_dim_qk,head_dim_vo,head_dim_ckv,head_dim_kpe,causal,q_dtype,kv_dtype,avg_actual_seq_len,random_actual_seq_len,m,n,k,group_size,tile_size,scale_major_mode,out_dtype,mma_sm,use_128x4_sf_layout,use_nvfp4,num_tokens,hidden_size,intermediate_size,num_experts,top_k,n_group,topk_group,routed_scaling_factor,local_expert_offset,local_num_experts,tile_tokens_dim,routing_method,use_shuffled_weight,weight_layout,use_routing_bias,use_routing_scales_on_input,input_dtype,weight_dtype,gated_act,cutlass_variant,quantized_input,tp_size,tp_rank,ep_size,ep_rank,refcheck,no_cuda_graph,use_cupti,allow_output_mismatch,random_seed,case_tag,generate_repro_command,repro_command +routine,median_time,std_time,tflops,tb_per_sec,backend,page_size,batch_size,s_qo,s_kv,num_qo_heads,num_kv_heads,head_dim_qk,head_dim_vo,head_dim_ckv,head_dim_kpe,causal,q_dtype,kv_dtype,avg_actual_seq_len,random_actual_seq_len,m,n,k,group_size,tile_size,scale_major_mode,out_dtype,mma_sm,use_128x4_sf_layout,use_nvfp4,num_tokens,hidden_size,intermediate_size,num_experts,top_k,n_group,topk_group,routed_scaling_factor,local_expert_offset,local_num_experts,routing_method,use_shuffled_weight,weight_layout,use_routing_bias,use_routing_scales_on_input,input_dtype,weight_dtype,gated_act,cutlass_variant,quantized_input,tp_size,tp_rank,ep_size,ep_rank,refcheck,no_cuda_graph,use_cupti,allow_output_mismatch,random_seed,case_tag,generate_repro_command,repro_command BatchPrefillWithPagedKVCacheWrapper,0.01244799979031086,0.0009464459008260536,13.963516944729905,0.3050282827732261,fa2,16,1,1024,1024,64,8,128,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B BatchPrefillWithPagedKVCacheWrapper,0.01839040070772171,0.00021363710731210026,9.45155349045863,0.20646597430613514,cudnn,16,1,1024,1024,64,8,128,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 
flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B BatchPrefillWithPagedKVCacheWrapper,0.008396799862384795,5.550615129103214e-05,20.70048814413847,0.45219512936224815,trtllm-gen,16,1,1024,1024,64,8,128,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B diff --git a/benchmarks/samples/sample_testlist_output.txt b/benchmarks/samples/sample_testlist_output.txt index 69a3961f87..d2c5cc4fa1 100644 --- a/benchmarks/samples/sample_testlist_output.txt +++ b/benchmarks/samples/sample_testlist_output.txt @@ -292,7 +292,7 @@ 2025-09-23 00:32:18,247 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends [PERF] cutlass_autotun:: median time 0.009 ms; std 0.000 ms; achieved tflops 6.372 TFLOPs/sec; achieved tb_per_sec 0.401 TB/sec [PERF] trtllm_autotune:: median time 0.011 ms; std 0.000 ms; achieved tflops 5.410 TFLOPs/sec; achieved tb_per_sec 0.340 TB/sec -[INFO] args = Namespace(routine='trtllm_fp4_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=256, top_k=8, n_group=8, topk_group=4, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, tile_tokens_dim=8, routing_method='deepseek_v3', use_shuffled_weight=True, weight_layout=0, use_routing_bias=True, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0) +[INFO] args = Namespace(routine='trtllm_fp4_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=256, top_k=8, n_group=8, topk_group=4, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='deepseek_v3', use_shuffled_weight=True, weight_layout=0, use_routing_bias=True, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0) [INFO] Running testTrtllmFp4BlockScaleMoe [INFO] FlashInfer version: 0.3.1 [VVERBOSE] gpu_name = 'NVIDIA_B200' @@ -303,7 +303,7 @@ [VVERBOSE] gemm1_weights_fp4.shape = 
torch.Size([256, 2048, 512]) [VVERBOSE] gemm2_weights_fp4.shape = torch.Size([256, 1024, 512]) [PERF] trtllm :: median time 0.224 ms; std 0.000 ms; achieved tflops 230.555 TFLOPs/sec; achieved tb_per_sec 1.818 TB/sec -[INFO] args = Namespace(routine='trtllm_fp4_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=128, top_k=8, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, tile_tokens_dim=8, routing_method='renormalize_naive', use_shuffled_weight=True, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=4, gated_act_type=0) +[INFO] args = Namespace(routine='trtllm_fp4_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=128, top_k=8, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='renormalize_naive', use_shuffled_weight=True, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=4, gated_act_type=0) [INFO] Running testTrtllmFp4BlockScaleMoe [INFO] FlashInfer version: 0.3.1 [VVERBOSE] gpu_name = 'NVIDIA_B200' @@ -314,7 +314,7 @@ [VVERBOSE] gemm1_weights_fp4.shape = torch.Size([128, 2048, 512]) [VVERBOSE] gemm2_weights_fp4.shape = torch.Size([128, 1024, 512]) [PERF] trtllm :: median time 0.226 ms; std 0.000 ms; achieved tflops 227.846 TFLOPs/sec; achieved tb_per_sec 0.903 TB/sec -[INFO] args = Namespace(routine='trtllm_fp8_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=256, top_k=8, n_group=8, topk_group=4, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, tile_tokens_dim=8, routing_method='deepseek_v3', use_shuffled_weight=True, weight_layout=0, use_routing_bias=True, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0) +[INFO] args = Namespace(routine='trtllm_fp8_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=256, top_k=8, n_group=8, 
topk_group=4, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='deepseek_v3', use_shuffled_weight=True, weight_layout=0, use_routing_bias=True, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0) [INFO] Running testTrtllmFp8BlockScaleMoe [INFO] FlashInfer version: 0.3.1 [VVERBOSE] gpu_name = 'NVIDIA_B200' @@ -325,7 +325,7 @@ [VVERBOSE] gemm1_weights_fp8.shape = torch.Size([256, 2048, 1024]) [VVERBOSE] gemm2_weights_fp8.shape = torch.Size([256, 1024, 1024]) [PERF] trtllm :: median time 0.557 ms; std 0.000 ms; achieved tflops 92.607 TFLOPs/sec; achieved tb_per_sec 1.455 TB/sec -[INFO] args = Namespace(routine='trtllm_fp8_per_tensor_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=128, top_k=1, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, tile_tokens_dim=8, routing_method='llama4', use_shuffled_weight=False, weight_layout=0, use_routing_bias=True, use_routing_scales_on_input=True, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=3, gated_act_type=0) +[INFO] args = Namespace(routine='trtllm_fp8_per_tensor_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=128, top_k=1, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='llama4', use_shuffled_weight=False, weight_layout=0, use_routing_bias=True, use_routing_scales_on_input=True, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=3, gated_act_type=0) [INFO] Running testTrtllmFp8PerTensorScaleMoe [INFO] FlashInfer version: 0.3.1 [VVERBOSE] gpu_name = 'NVIDIA_B200' @@ -336,7 +336,7 @@ [VVERBOSE] gemm1_weights_fp8.shape = torch.Size([128, 2048, 1024]) [VVERBOSE] gemm2_weights_fp8.shape = torch.Size([128, 1024, 1024]) [PERF] trtllm :: median time 0.123 ms; std 0.000 ms; achieved tflops 52.340 TFLOPs/sec; achieved tb_per_sec 3.299 TB/sec -[INFO] args = Namespace(routine='trtllm_fp8_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=128, top_k=1, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, tile_tokens_dim=8, routing_method='renormalize', use_shuffled_weight=True, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, 
input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=1, gated_act_type=0) +[INFO] args = Namespace(routine='trtllm_fp8_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=128, top_k=1, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='renormalize', use_shuffled_weight=True, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=1, gated_act_type=0) [INFO] Running testTrtllmFp8BlockScaleMoe [INFO] FlashInfer version: 0.3.1 [VVERBOSE] gpu_name = 'NVIDIA_B200' @@ -347,7 +347,7 @@ [VVERBOSE] gemm1_weights_fp8.shape = torch.Size([128, 2048, 1024]) [VVERBOSE] gemm2_weights_fp8.shape = torch.Size([128, 1024, 1024]) [PERF] trtllm :: median time 0.109 ms; std 0.000 ms; achieved tflops 59.297 TFLOPs/sec; achieved tb_per_sec 3.740 TB/sec -[INFO] args = Namespace(routine='cutlass_fused_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='cutlass_moe_base', generate_repro_command=True, repro_command='', num_tokens=32, hidden_size=128, intermediate_size=128, num_experts=2, top_k=2, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, tile_tokens_dim=8, routing_method='deepseek_v3', use_shuffled_weight=False, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, input_dtype='float16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0) +[INFO] args = Namespace(routine='cutlass_fused_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='cutlass_moe_base', generate_repro_command=True, repro_command='', num_tokens=32, hidden_size=128, intermediate_size=128, num_experts=2, top_k=2, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='deepseek_v3', use_shuffled_weight=False, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, input_dtype='float16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0) [INFO] Running testCutlassFusedMoe [INFO] FlashInfer version: 0.3.1 [VVERBOSE] gpu_name = 'NVIDIA_B200' diff --git a/ci/docker-tags.yml b/ci/docker-tags.yml index ba3a947bc6..36fe4a6920 100644 --- a/ci/docker-tags.yml +++ b/ci/docker-tags.yml @@ -1,4 +1,4 @@ -flashinfer/flashinfer-ci-cu126: 20251024-0e48aaf -flashinfer/flashinfer-ci-cu128: 20251024-0e48aaf -flashinfer/flashinfer-ci-cu129: 20251024-0e48aaf -flashinfer/flashinfer-ci-cu130: 
20251024-0e48aaf +flashinfer/flashinfer-ci-cu126: 20251206-185d63a +flashinfer/flashinfer-ci-cu128: 20251206-185d63a +flashinfer/flashinfer-ci-cu129: 20251206-185d63a +flashinfer/flashinfer-ci-cu130: 20251206-185d63a diff --git a/csrc/batch_attention.cu b/csrc/batch_attention.cu index a3d36b7981..b37a9a6a18 100644 --- a/csrc/batch_attention.cu +++ b/csrc/batch_attention.cu @@ -48,7 +48,7 @@ Array BatchPagedAttentionPlan(TensorView float_workspace_buffer, HolisticPlanInfo<2> plan_info; - cudaSetDevice(float_workspace_buffer.device().device_id); + ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id); const cudaStream_t stream = get_stream(float_workspace_buffer.device()); cudaError_t status = TwoStageHolisticPlan( @@ -102,7 +102,7 @@ void BatchPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_wo v_stride_n = v_cache.stride(2); } - cudaSetDevice(q.device().device_id); + ffi::CUDADeviceGuard device_guard(q.device().device_id); const cudaStream_t stream = get_stream(q.device()); DISPATCH_context( diff --git a/csrc/batch_decode.cu b/csrc/batch_decode.cu index c3ce1e2ecf..8cc31fbe01 100644 --- a/csrc/batch_decode.cu +++ b/csrc/batch_decode.cu @@ -42,6 +42,8 @@ Array BatchDecodeWithPagedKVCachePlan( int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo, TensorView empty_q_data, TensorView empty_kv_data) { + CHECK_INPUT_TYPE(indptr, dl_int32); + size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = @@ -53,7 +55,7 @@ Array BatchDecodeWithPagedKVCachePlan( << "CUDA cores template only supports equal head dim for QK and VO, please use tensor " "cores template for different head dim"; - cudaSetDevice(float_workspace_buffer.device().device_id); + ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id); const cudaStream_t stream = get_stream(float_workspace_buffer.device()); DISPATCH_context( DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, @@ -86,6 +88,10 @@ void BatchDecodeWithPagedKVCacheRun(TensorView float_workspace_buffer, TensorView o, Optional maybe_lse, int64_t kv_layout_code, int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS) { + CHECK_INPUT_TYPE(paged_kv_indptr, dl_int32); + CHECK_INPUT_TYPE(paged_kv_indices, dl_int32); + CHECK_INPUT_TYPE(paged_kv_last_page_len, dl_int32); + DecodePlanInfo plan_info; plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); QKVLayout kv_layout = static_cast(kv_layout_code); @@ -130,7 +136,7 @@ void BatchDecodeWithPagedKVCacheRun(TensorView float_workspace_buffer, } kv_cache_strides = k_strides.data(); - cudaSetDevice(q.device().device_id); + ffi::CUDADeviceGuard device_guard(q.device().device_id); const cudaStream_t stream = get_stream(q.device()); DISPATCH_context( diff --git a/csrc/batch_decode_mla_cute_sm80.cu b/csrc/batch_decode_mla_cute_sm80.cu index 5679076438..45b708018f 100644 --- a/csrc/batch_decode_mla_cute_sm80.cu +++ b/csrc/batch_decode_mla_cute_sm80.cu @@ -23,7 +23,7 @@ Array BatchDecodeWithPagedKVCachePlanMLA(ffi::TensorView float_workspac int_workspace_buffer.size(0) * get_element_size(int_workspace_buffer); DecodePlanInfo plan_info; - cudaSetDevice(float_workspace_buffer.device().device_id); + ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id); const cudaStream_t stream = 
get_stream(float_workspace_buffer.device()); auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMlaCuteSM80< @@ -103,7 +103,7 @@ void BatchDecodeWithPagedKVCacheRunMLA( } params.padded_batch_size = plan_info.padded_batch_size; - cudaSetDevice(paged_ckv_cache.device().device_id); + ffi::CUDADeviceGuard device_guard(paged_ckv_cache.device().device_id); const cudaStream_t stream = get_stream(paged_ckv_cache.device()); cudaError_t status = BatchDecodeWithPagedKVCacheDispatchedMlaCuteSM80( diff --git a/csrc/batch_decode_mla_plan.cu b/csrc/batch_decode_mla_plan.cu index 7925a14f27..e409cde882 100644 --- a/csrc/batch_decode_mla_plan.cu +++ b/csrc/batch_decode_mla_plan.cu @@ -15,7 +15,9 @@ Array BatchDecodeWithPagedKVCachePlanMLA(TensorView float_workspace_buf TensorView indptr, int64_t batch_size, int64_t num_qo_heads, int64_t page_size, bool enable_cuda_graph) { - cudaSetDevice(float_workspace_buffer.device().device_id); + CHECK_INPUT_TYPE(indptr, dl_int32); + + ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id); const cudaStream_t stream = get_stream(float_workspace_buffer.device()); size_t float_workspace_size_in_bytes = diff --git a/csrc/batch_decode_mla_run.cu b/csrc/batch_decode_mla_run.cu index 35d533b536..94b5e35e0b 100644 --- a/csrc/batch_decode_mla_run.cu +++ b/csrc/batch_decode_mla_run.cu @@ -17,6 +17,10 @@ void BatchDecodeWithPagedKVCacheRunMLA( TensorView paged_kv_last_page_len, TensorView o, double sm_scale, int64_t window_left, double logits_soft_cap, double rope_scale, double rope_theta, Optional maybe_lse, bool enable_pdl) { + CHECK_INPUT_TYPE(paged_kv_indptr, dl_int32); + CHECK_INPUT_TYPE(paged_kv_indices, dl_int32); + CHECK_INPUT_TYPE(paged_kv_last_page_len, dl_int32); + DecodePlanInfo plan_info; plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); @@ -35,7 +39,7 @@ void BatchDecodeWithPagedKVCacheRunMLA( void* float_buffer = static_cast(float_workspace_buffer.data_ptr()); void* int_buffer = static_cast(int_workspace_buffer.data_ptr()); - cudaSetDevice(q_nope.device().device_id); + ffi::CUDADeviceGuard device_guard(q_nope.device().device_id); const cudaStream_t stream = get_stream(q_nope.device()); paged_kv_mla_t paged_kv( diff --git a/csrc/batch_mla_binding.cu b/csrc/batch_mla_binding.cu index 6822e28b93..b39192de6a 100644 --- a/csrc/batch_mla_binding.cu +++ b/csrc/batch_mla_binding.cu @@ -31,7 +31,8 @@ void BatchMLAPagedAttentionRun(TensorView float_workspace_buffer, TensorView int Array plan_info_vec, TensorView q_nope, TensorView q_pe, TensorView ckv_cache, TensorView kpe_cache, TensorView kv_indices, TensorView o, Optional maybe_lse, int64_t mask_mode_code, - int64_t num_heads, int64_t page_size, double sm_scale); + int64_t num_heads, int64_t page_size, double sm_scale, + bool return_lse_base_on_e); TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, BatchMLAPagedAttentionPlan); TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, BatchMLAPagedAttentionRun); diff --git a/csrc/batch_mla_plan.cu b/csrc/batch_mla_plan.cu index 1f7176e452..f4e8bc4bda 100644 --- a/csrc/batch_mla_plan.cu +++ b/csrc/batch_mla_plan.cu @@ -29,6 +29,10 @@ Array BatchMLAPagedAttentionPlan(TensorView float_workspace_buffer, TensorView qo_indptr, TensorView kv_indptr, TensorView kv_len, int64_t num_heads, int64_t head_dim_o, bool causal) { + CHECK_INPUT_TYPE(qo_indptr, dl_int32); + CHECK_INPUT_TYPE(kv_indptr, dl_int32); + CHECK_INPUT_TYPE(kv_len, dl_int32); + size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * 
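// [Editorial sketch, not part of this patch] The recurring change in these hunks swaps the bare
// cudaSetDevice(...) call for ffi::CUDADeviceGuard, i.e. a one-way device switch becomes an RAII
// guard that restores the caller's device when the function returns. The accompanying
// CHECK_INPUT_TYPE(..., dl_int32) additions validate that indptr/indices tensors arrive as int32
// before planning. Assuming the tvm-ffi guard follows the usual RAII pattern, a minimal
// equivalent looks like:
#include <cuda_runtime.h>

class CUDADeviceGuardSketch {
 public:
  explicit CUDADeviceGuardSketch(int device) {
    cudaGetDevice(&prev_device_);                       // remember the caller's device
    if (device != prev_device_) cudaSetDevice(device);  // switch only if needed
  }
  ~CUDADeviceGuardSketch() { cudaSetDevice(prev_device_); }  // restore on scope exit

 private:
  int prev_device_ = 0;
};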
get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = @@ -38,7 +42,7 @@ Array BatchMLAPagedAttentionPlan(TensorView float_workspace_buffer, int batch_size = kv_len.size(0); - cudaSetDevice(float_workspace_buffer.device().device_id); + ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id); const cudaStream_t stream = get_stream(float_workspace_buffer.device()); cudaError_t status = diff --git a/csrc/batch_mla_run.cu b/csrc/batch_mla_run.cu index dfa2442f1b..3dc142d1f4 100644 --- a/csrc/batch_mla_run.cu +++ b/csrc/batch_mla_run.cu @@ -31,11 +31,14 @@ void BatchMLAPagedAttentionRun(TensorView float_workspace_buffer, TensorView int Array plan_info_vec, TensorView q_nope, TensorView q_pe, TensorView ckv_cache, TensorView kpe_cache, TensorView kv_indices, TensorView o, Optional maybe_lse, int64_t mask_mode_code, - int64_t num_heads, int64_t page_size, double sm_scale) { + int64_t num_heads, int64_t page_size, double sm_scale, + bool return_lse_base_on_e) { // q_nope: [n, num_heads, head_dim_ckv] // q_pe: [n, num_heads, head_dim_kpe] // ckv_cache: [num_pages, page_size, head_dim_ckv] // kpe_cache: [num_pages, page_size, head_dim_kpe] + CHECK_INPUT_TYPE(kv_indices, dl_int32); + MLAPlanInfo plan_info; plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); @@ -55,7 +58,7 @@ void BatchMLAPagedAttentionRun(TensorView float_workspace_buffer, TensorView int unsigned int o_stride_n = o.stride(0); unsigned int o_stride_h = o.stride(1); - cudaSetDevice(q_nope.device().device_id); + ffi::CUDADeviceGuard device_guard(q_nope.device().device_id); const cudaStream_t stream = get_stream(q_nope.device()); DISPATCH_context( @@ -112,6 +115,7 @@ void BatchMLAPagedAttentionRun(TensorView float_workspace_buffer, TensorView int params.o_stride_h = o_stride_h; params.sm_scale = sm_scale; + params.return_lse_base_on_e = return_lse_base_on_e; cudaError_t status = mla::BatchMLAPagedAttention( params, plan_info.num_blks_x, plan_info.num_blks_y, stream); diff --git a/csrc/batch_mla_sm90_binding.cu b/csrc/batch_mla_sm90_binding.cu index 2e6cd1aa7d..f2af49766a 100644 --- a/csrc/batch_mla_sm90_binding.cu +++ b/csrc/batch_mla_sm90_binding.cu @@ -32,8 +32,8 @@ void BatchMLAPagedAttentionSM90Run(TensorView float_workspace_buffer, TensorView q_nope, TensorView q_pe, TensorView ckv_cache, TensorView kpe_cache, TensorView kv_indices, TensorView o, Optional maybe_lse, int64_t mask_mode_code, - int64_t num_heads, int64_t page_size, - double sm_scale ADDITIONAL_FUNC_PARAMS); + int64_t num_heads, int64_t page_size, double sm_scale, + bool return_lse_base_on_e ADDITIONAL_FUNC_PARAMS); TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, BatchMLAPagedAttentionSM90Plan); TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, BatchMLAPagedAttentionSM90Run); diff --git a/csrc/batch_mla_sm90_plan.cu b/csrc/batch_mla_sm90_plan.cu index d297ebab90..e51932e64b 100644 --- a/csrc/batch_mla_sm90_plan.cu +++ b/csrc/batch_mla_sm90_plan.cu @@ -38,7 +38,7 @@ Array BatchMLAPagedAttentionSM90Plan(TensorView float_workspace_buffer, int batch_size = kv_len.size(0); - cudaSetDevice(float_workspace_buffer.device().device_id); + ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id); const cudaStream_t stream = get_stream(float_workspace_buffer.device()); cudaError_t status = diff --git a/csrc/batch_mla_sm90_run.cu b/csrc/batch_mla_sm90_run.cu index 8d6d80c223..b47a7ff7dc 100644 --- a/csrc/batch_mla_sm90_run.cu +++ b/csrc/batch_mla_sm90_run.cu @@ -31,8 +31,8 @@ void 
BatchMLAPagedAttentionSM90Run(TensorView float_workspace_buffer, TensorView q_nope, TensorView q_pe, TensorView ckv_cache, TensorView kpe_cache, TensorView kv_indices, TensorView o, Optional maybe_lse, int64_t mask_mode_code, - int64_t num_heads, int64_t page_size, - double sm_scale ADDITIONAL_FUNC_PARAMS) { + int64_t num_heads, int64_t page_size, double sm_scale, + bool return_lse_base_on_e ADDITIONAL_FUNC_PARAMS) { // q_nope: [n, num_heads, head_dim_ckv] // q_pe: [n, num_heads, head_dim_kpe] // ckv_cache: [num_pages, page_size, head_dim_ckv] @@ -56,7 +56,7 @@ void BatchMLAPagedAttentionSM90Run(TensorView float_workspace_buffer, unsigned int o_stride_n = o.stride(0); unsigned int o_stride_h = o.stride(1); - cudaSetDevice(q_nope.device().device_id); + ffi::CUDADeviceGuard device_guard(q_nope.device().device_id); const cudaStream_t stream = get_stream(q_nope.device()); DISPATCH_context( @@ -111,6 +111,7 @@ void BatchMLAPagedAttentionSM90Run(TensorView float_workspace_buffer, params.kpe_stride_n = kpe_stride_n; params.o_stride_n = o_stride_n; params.o_stride_h = o_stride_h; + params.return_lse_base_on_e = return_lse_base_on_e; ADDITIONAL_PARAMS_SETTER diff --git a/csrc/batch_pod.cu b/csrc/batch_pod.cu new file mode 100644 index 0000000000..98ff9d83da --- /dev/null +++ b/csrc/batch_pod.cu @@ -0,0 +1,350 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include + +#include "batch_pod_config.inc" +#include "tvm_ffi_utils.h" + +namespace flashinfer { +template +cudaError_t BatchPODWithKVCacheTensorDispatched(PrefillParams prefill_params, + typename PrefillParams::DTypeO* tmp_v_p, + float* tmp_s_p, DecodeParams decode_params, + typename DecodeParams::DTypeO* tmp_v_d, + float* tmp_s_d, bool enable_pdl, + cudaStream_t stream, int* sm_aware_sched); + +} // namespace flashinfer + +using namespace flashinfer; + +using tvm::ffi::Array; +using tvm::ffi::Optional; + +void batch_pod_with_kv_cache_tensor( + // Prefill params + TensorView float_workspace_buffer_p, TensorView int_workspace_buffer_p, + Array plan_info_vec_p, TensorView q_p, TensorView paged_k_cache_p, + TensorView paged_v_cache_p, TensorView qo_indptr_p, TensorView paged_kv_indptr_p, + TensorView paged_kv_indices_p, TensorView paged_kv_last_page_len_p, TensorView o_p, + Optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, + int64_t window_left_p, Optional maybe_custom_mask_p, + Optional maybe_mask_indptr_p, Optional maybe_alibi_slopes_p, + double logits_soft_cap_p, double sm_scale_p, double rope_rcp_scale_p, double rope_rcp_theta_p, + // Decode params + TensorView float_workspace_buffer_d, TensorView int_workspace_buffer_d, + Array plan_info_vec_d, TensorView q_d, TensorView paged_k_cache_d, + TensorView paged_v_cache_d, TensorView qo_indptr_d, TensorView paged_kv_indptr_d, + TensorView paged_kv_indices_d, TensorView paged_kv_last_page_len_d, TensorView o_d, + Optional maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d, + int64_t window_left_d, Optional maybe_custom_mask_d, + Optional maybe_mask_indptr_d, Optional maybe_alibi_slopes_d, + double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, + bool enable_pdl, TensorView sm_aware_sched) { + // Prefill setup + PrefillPlanInfo plan_info_p; + plan_info_p.FromVector(std::vector(plan_info_vec_p.begin(), plan_info_vec_p.end())); + QKVLayout kv_layout_p = static_cast(layout_p); + int64_t batch_size_p = paged_kv_indptr_p.size(0) - 1; + int64_t num_qo_heads = q_p.size(1); + + int64_t num_kv_heads_p, page_size_p; + uint32_t head_dim_qk_p = q_p.size(2); + if (kv_layout_p == QKVLayout::kHND) { + num_kv_heads_p = paged_k_cache_p.size(1); + page_size_p = paged_k_cache_p.size(2); + } else { + page_size_p = paged_k_cache_p.size(1); + num_kv_heads_p = paged_k_cache_p.size(2); + } + + if (maybe_lse_p.has_value()) { + const auto& lse = maybe_lse_p.value(); + TVM_FFI_ICHECK_EQ(lse.size(0), q_p.size(0)); + TVM_FFI_ICHECK_EQ(lse.size(1), q_p.size(1)); + } + + void* float_buffer_ptr_p = static_cast(float_workspace_buffer_p.data_ptr()); + void* int_buffer_ptr_p = static_cast(int_workspace_buffer_p.data_ptr()); + + const MaskMode mask_mode_p = static_cast(mask_mode_code_p); + + // get q_stride_n and q_stride_h + const auto q_stride_n_p = q_p.stride(0); + const auto q_stride_h_p = q_p.stride(1); + + // get kv_cache_strides + const int64_t* kv_cache_strides_p = nullptr; + auto k_strides_p = paged_k_cache_p.strides(); + auto v_strides_p = paged_v_cache_p.strides(); + TVM_FFI_ICHECK_EQ(k_strides_p.size(), v_strides_p.size()); + for (int i = 0; i < k_strides_p.size(); ++i) { + TVM_FFI_ICHECK_EQ(k_strides_p[i], v_strides_p[i]); + } + kv_cache_strides_p = k_strides_p.data(); + + ffi::CUDADeviceGuard device_guard(float_workspace_buffer_p.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer_p.device()); + + // Decode setup (TensorView decode = batched prefill) + 
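// [Editorial note, not part of this patch] In this new batch POD entry point the decode half
// reuses the prefill machinery: DecodeParams is the same BatchPrefillPagedParams structure as
// PrefillParams (see the customize config further below), and the dispatch at the end of this
// function pins CTA_TILE_Q_D to 16, a small query tile suited to the short (typically one-token)
// queries of the decode batch.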
PrefillPlanInfo plan_info_d; + plan_info_d.FromVector(std::vector(plan_info_vec_d.begin(), plan_info_vec_d.end())); + QKVLayout kv_layout_d = static_cast(layout_d); + int64_t batch_size_d = paged_kv_indptr_d.size(0) - 1; + int64_t num_qo_heads_d = q_d.size(1); + + TVM_FFI_ICHECK_EQ(num_qo_heads, num_qo_heads_d) + << "POD currently requires same # Query heads for prefill and decode"; + + int64_t num_kv_heads_d, page_size_d; + uint32_t head_dim_qk_d = q_d.size(2); + if (kv_layout_d == QKVLayout::kHND) { + num_kv_heads_d = paged_k_cache_d.size(1); + page_size_d = paged_k_cache_d.size(2); + } else { + page_size_d = paged_k_cache_d.size(1); + num_kv_heads_d = paged_k_cache_d.size(2); + } + TVM_FFI_ICHECK_EQ(num_kv_heads_p, num_kv_heads_d) + << "POD currently requires same # KV heads for prefill and decode; Prefill: " + << num_kv_heads_p << ", Decode: " << num_kv_heads_d; + + if (maybe_lse_d.has_value()) { + const auto& lse = maybe_lse_d.value(); + TVM_FFI_ICHECK_EQ(lse.size(0), q_d.size(0)); + TVM_FFI_ICHECK_EQ(lse.size(1), q_d.size(1)); + } + + void* float_buffer_ptr_d = static_cast(float_workspace_buffer_d.data_ptr()); + void* int_buffer_ptr_d = static_cast(int_workspace_buffer_d.data_ptr()); + + const MaskMode mask_mode_d = static_cast(mask_mode_code_d); + + // get q_stride_n and q_stride_h + const auto q_stride_n_d = q_d.stride(0); + const auto q_stride_h_d = q_d.stride(1); + + // get kv_cache_strides + const int64_t* kv_cache_strides_d = nullptr; + auto k_strides_d = paged_k_cache_d.strides(); + auto v_strides_d = paged_v_cache_d.strides(); + TVM_FFI_ICHECK_EQ(k_strides_d.size(), v_strides_d.size()); + for (int i = 0; i < k_strides_d.size(); ++i) { + TVM_FFI_ICHECK_EQ(k_strides_d[i], v_strides_d[i]); + } + kv_cache_strides_d = k_strides_d.data(); + + // Already handled by prefill + // ffi::CUDADeviceGuard device_guard(float_workspace_buffer_d.device().device_id); + // const cudaStream_t stream = get_stream(float_workspace_buffer_d.device()); + + DISPATCH_context( + MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, USE_SLIDING_WINDOW_P, + USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, [&] { + PrefillParams prefill_params; + DTypeO* tmp_v_p = nullptr; + float* tmp_s_p = nullptr; + { + PrefillParams& params = prefill_params; + params.q = static_cast(q_p.data_ptr()); + paged_kv_t paged_kv( + num_kv_heads_p, page_size_p, HEAD_DIM_VO, batch_size_p, kv_layout_p, + static_cast(paged_k_cache_p.data_ptr()), + static_cast(paged_v_cache_p.data_ptr()), kv_cache_strides_p, + static_cast(paged_kv_indices_p.data_ptr()), + static_cast(paged_kv_indptr_p.data_ptr()), + static_cast(paged_kv_last_page_len_p.data_ptr())); + params.paged_kv = paged_kv; + params.q_indptr = static_cast(qo_indptr_p.data_ptr()); + params.o = static_cast(o_p.data_ptr()); + + params.lse = maybe_lse_p.has_value() ? static_cast(maybe_lse_p.value().data_ptr()) + : nullptr; + params.num_qo_heads = num_qo_heads; + params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads); + params.q_stride_n = q_stride_n_p; + params.q_stride_h = q_stride_h_p; + params.window_left = window_left_p; + + params.request_indices = nullptr; + params.qo_tile_indices = nullptr; + params.kv_tile_indices = nullptr; + params.merge_indptr = nullptr; + params.o_indptr = nullptr; + params.kv_chunk_size_ptr = nullptr; + params.block_valid_mask = nullptr; + params.total_num_rows = nullptr; + params.max_total_num_rows = 0; + params.padded_batch_size = 0; + params.partition_kv = false; + + params.maybe_mask_indptr = + maybe_mask_indptr_p.has_value() + ? 
static_cast(maybe_mask_indptr_p.value().data_ptr()) + : nullptr; + params.maybe_alibi_slopes = + maybe_alibi_slopes_p.has_value() + ? static_cast(maybe_alibi_slopes_p.value().data_ptr()) + : nullptr; + params.logits_soft_cap = logits_soft_cap_p; + params.sm_scale = sm_scale_p; + params.rope_rcp_scale = rope_rcp_scale_p; + params.rope_rcp_theta = rope_rcp_theta_p; + + params.request_indices = + GetPtrFromBaseOffset(int_buffer_ptr_p, plan_info_p.request_indices_offset); + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr_p, plan_info_p.qo_tile_indices_offset); + params.kv_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr_p, plan_info_p.kv_tile_indices_offset); + params.o_indptr = + GetPtrFromBaseOffset(int_buffer_ptr_p, plan_info_p.o_indptr_offset); + params.kv_chunk_size_ptr = + GetPtrFromBaseOffset(int_buffer_ptr_p, plan_info_p.kv_chunk_size_ptr_offset); + if (plan_info_p.split_kv) { + params.merge_indptr = + GetPtrFromBaseOffset(int_buffer_ptr_p, plan_info_p.merge_indptr_offset); + tmp_v_p = GetPtrFromBaseOffset(float_buffer_ptr_p, plan_info_p.v_offset); + tmp_s_p = GetPtrFromBaseOffset(float_buffer_ptr_p, plan_info_p.s_offset); + if (plan_info_p.enable_cuda_graph) { + params.block_valid_mask = + GetPtrFromBaseOffset(int_buffer_ptr_p, plan_info_p.block_valid_mask_offset); + } + } + params.padded_batch_size = plan_info_p.padded_batch_size; + params.max_total_num_rows = plan_info_p.total_num_rows; + if (plan_info_p.enable_cuda_graph) { + params.total_num_rows = + GetPtrFromBaseOffset(int_buffer_ptr_p, plan_info_p.total_num_rows_offset); + } + } + + DecodeParams decode_params; + DTypeO* tmp_v_d = nullptr; + float* tmp_s_d = nullptr; + { + DecodeParams& params = decode_params; + params.q = static_cast(q_d.data_ptr()); + paged_kv_t paged_kv( + num_kv_heads_d, page_size_d, HEAD_DIM_VO, batch_size_d, kv_layout_d, + static_cast(paged_k_cache_d.data_ptr()), + static_cast(paged_v_cache_d.data_ptr()), kv_cache_strides_d, + static_cast(paged_kv_indices_d.data_ptr()), + static_cast(paged_kv_indptr_d.data_ptr()), + static_cast(paged_kv_last_page_len_d.data_ptr())); + params.paged_kv = paged_kv; + params.q_indptr = static_cast(qo_indptr_d.data_ptr()); + params.o = static_cast(o_d.data_ptr()); + + params.lse = maybe_lse_d.has_value() ? static_cast(maybe_lse_d.value().data_ptr()) + : nullptr; + params.num_qo_heads = num_qo_heads; + params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads); + params.q_stride_n = q_stride_n_d; + params.q_stride_h = q_stride_h_d; + params.window_left = window_left_d; + + params.request_indices = nullptr; + params.qo_tile_indices = nullptr; + params.kv_tile_indices = nullptr; + params.merge_indptr = nullptr; + params.o_indptr = nullptr; + params.kv_chunk_size_ptr = nullptr; + params.block_valid_mask = nullptr; + params.total_num_rows = nullptr; + params.max_total_num_rows = 0; + params.padded_batch_size = 0; + params.partition_kv = false; + + params.maybe_mask_indptr = + maybe_mask_indptr_d.has_value() + ? static_cast(maybe_mask_indptr_d.value().data_ptr()) + : nullptr; + params.maybe_alibi_slopes = + maybe_alibi_slopes_d.has_value() + ? 
static_cast(maybe_alibi_slopes_d.value().data_ptr()) + : nullptr; + params.logits_soft_cap = logits_soft_cap_d; + params.sm_scale = sm_scale_d; + params.rope_rcp_scale = rope_rcp_scale_d; + params.rope_rcp_theta = rope_rcp_theta_d; + + params.request_indices = + GetPtrFromBaseOffset(int_buffer_ptr_d, plan_info_d.request_indices_offset); + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr_d, plan_info_d.qo_tile_indices_offset); + params.kv_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr_d, plan_info_d.kv_tile_indices_offset); + params.o_indptr = + GetPtrFromBaseOffset(int_buffer_ptr_d, plan_info_d.o_indptr_offset); + params.kv_chunk_size_ptr = + GetPtrFromBaseOffset(int_buffer_ptr_d, plan_info_d.kv_chunk_size_ptr_offset); + if (plan_info_d.split_kv) { + params.merge_indptr = + GetPtrFromBaseOffset(int_buffer_ptr_d, plan_info_d.merge_indptr_offset); + tmp_v_d = GetPtrFromBaseOffset(float_buffer_ptr_d, plan_info_d.v_offset); + tmp_s_d = GetPtrFromBaseOffset(float_buffer_ptr_d, plan_info_d.s_offset); + if (plan_info_d.enable_cuda_graph) { + params.block_valid_mask = + GetPtrFromBaseOffset(int_buffer_ptr_d, plan_info_d.block_valid_mask_offset); + } + } + params.padded_batch_size = plan_info_d.padded_batch_size; + params.max_total_num_rows = plan_info_d.total_num_rows; + if (plan_info_d.enable_cuda_graph) { + params.total_num_rows = + GetPtrFromBaseOffset(int_buffer_ptr_d, plan_info_d.total_num_rows_offset); + } + } + + constexpr bool use_custom_mask_p = MASK_MODE_P == MaskMode::kCustom; + using PrefillAttentionVariant = + DefaultAttention; + constexpr bool use_custom_mask_d = MASK_MODE_D == MaskMode::kCustom; + using DecodeAttentionVariant = + DefaultAttention; + + int dev_id = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + int num_sm = 0; + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + // SM-aware scheduling buffer uses num_sm + 2 entries + // num_sm entries for counters for each SM, and + // 2 entries for keeping track of blockIds for prefill and decode + assert( + sm_aware_sched.ndim() == 1 && sm_aware_sched.size(0) == num_sm + 2 && + "sm_aware_sched tensor has incorrect shape or type, should be (num_sm + 2,) of int32"); + DISPATCH_CTA_TILE_Q(plan_info_p.cta_tile_q, CTA_TILE_Q_P, { + constexpr size_t CTA_TILE_Q_D = 16; + cudaError_t status = flashinfer::BatchPODWithKVCacheTensorDispatched< + HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, CTA_TILE_Q_P, + MASK_MODE_P, CTA_TILE_Q_D, MASK_MODE_D, PrefillAttentionVariant, + DecodeAttentionVariant>(prefill_params, tmp_v_p, tmp_s_p, decode_params, tmp_v_d, + tmp_s_d, enable_pdl, stream, + static_cast(sm_aware_sched.data_ptr())); + TVM_FFI_ICHECK(status == cudaSuccess) + << "BatchPODWithKVCache kernel launch failed, error: " << cudaGetErrorString(status); + return status; + }); + }); +} diff --git a/csrc/batch_pod_customize_config.jinja b/csrc/batch_pod_customize_config.jinja new file mode 100644 index 0000000000..9f27b42953 --- /dev/null +++ b/csrc/batch_pod_customize_config.jinja @@ -0,0 +1,43 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace flashinfer; + +using DTypeQ = {{ dtype_q }}; +using DTypeKV = {{ dtype_kv }}; +using DTypeO = {{ dtype_o }}; +using IdType = {{ idtype }}; +constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; +constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; +constexpr bool USE_FP16_QK_REDUCTION = {{ use_fp16_qk_reduction }}; + +constexpr auto 
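// [Editorial sketch, not part of this patch] batch_pod_with_kv_cache_tensor above expects
// sm_aware_sched to be an int32 tensor of num_sm + 2 entries: one counter per SM plus two slots
// tracking the current block ids for prefill and decode. Assuming a caller that manages raw CUDA
// memory itself (in practice the Python side passes a tensor), the buffer could be prepared as
// below; zero-initialization is an assumption here, the patch itself only asserts shape and dtype.
#include <cuda_runtime.h>

int* alloc_sm_aware_sched_buffer(int device_id) {
  int num_sm = 0;
  cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, device_id);
  int* buf = nullptr;
  cudaMalloc(&buf, sizeof(int) * (num_sm + 2));    // num_sm counters + 2 block-id slots
  cudaMemset(buf, 0, sizeof(int) * (num_sm + 2));  // assumed: scheduler starts from zeroed state
  return buf;
}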
USE_LOGITS_SOFT_CAP_P = {{ use_logits_soft_cap_p }}; +constexpr auto POS_ENCODING_MODE_P = {{ pos_encoding_mode_p }}; +constexpr auto USE_SLIDING_WINDOW_P = {{ use_sliding_window_p }}; + +constexpr auto USE_LOGITS_SOFT_CAP_D = {{ use_logits_soft_cap_d }}; +constexpr auto POS_ENCODING_MODE_D = {{ pos_encoding_mode_d }}; +constexpr auto USE_SLIDING_WINDOW_D = {{ use_sliding_window_d }}; + +constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; +constexpr bool USE_LOGITS_SOFT_CAP = false; + +using PrefillParams = BatchPrefillPagedParams; +using DecodeParams = BatchPrefillPagedParams; + +#define DISPATCH_context(MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, \ + USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, ...) \ + DISPATCH_MASK_MODE(mask_mode_p, MASK_MODE_P, { \ + DISPATCH_MASK_MODE(mask_mode_d, MASK_MODE_D, { \ + __VA_ARGS__(); \ + }); \ +}); diff --git a/csrc/batch_pod_jit_binding.cu b/csrc/batch_pod_jit_binding.cu new file mode 100644 index 0000000000..c7a8a5ea6b --- /dev/null +++ b/csrc/batch_pod_jit_binding.cu @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2023-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "batch_pod_config.inc" +#include "tvm_ffi_utils.h" + +using tvm::ffi::Array; +using tvm::ffi::Optional; + +void batch_pod_with_kv_cache_tensor( + // Prefill params + TensorView float_workspace_buffer_p, TensorView int_workspace_buffer_p, + Array plan_info_vec_p, TensorView q_p, TensorView paged_k_cache_p, + TensorView paged_v_cache_p, TensorView qo_indptr_p, TensorView paged_kv_indptr_p, + TensorView paged_kv_indices_p, TensorView paged_kv_last_page_len_p, TensorView o_p, + Optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, + int64_t window_left_p, Optional maybe_custom_mask_p, + Optional maybe_mask_indptr_p, Optional maybe_alibi_slopes_p, + double logits_soft_cap_p, double sm_scale_p, double rope_rcp_scale_p, double rope_rcp_theta_p, + // Decode params + TensorView float_workspace_buffer_d, TensorView int_workspace_buffer_d, + Array plan_info_vec_d, TensorView q_d, TensorView paged_k_cache_d, + TensorView paged_v_cache_d, TensorView qo_indptr_d, TensorView paged_kv_indptr_d, + TensorView paged_kv_indices_d, TensorView paged_kv_last_page_len_d, TensorView o_d, + Optional maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d, + int64_t window_left_d, Optional maybe_custom_mask_d, + Optional maybe_mask_indptr_d, Optional maybe_alibi_slopes_d, + double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, + bool enable_pdl, TensorView sm_aware_sched); + +// Batch-request prefill attention with KV-Cache operator +TVM_FFI_DLL_EXPORT_TYPED_FUNC(batch_pod_with_kv_cache_tensor, batch_pod_with_kv_cache_tensor); diff --git a/csrc/batch_pod_kernel_inst.jinja b/csrc/batch_pod_kernel_inst.jinja new file mode 100644 index 0000000000..cb2c39d32b --- /dev/null +++ b/csrc/batch_pod_kernel_inst.jinja @@ -0,0 +1,31 @@ +#include +#include +#include +#include +#include +#include 
+#include +#include +#include + +#include "batch_pod_config.inc" + +using namespace flashinfer; + +namespace flashinfer { +constexpr auto use_custom_mask_p = {{ mask_mode_p }} == MaskMode::kCustom; +constexpr auto use_custom_mask_d = {{ mask_mode_d }} == MaskMode::kCustom; +// Not sure about the below declaration +constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; + +{% for cta_tile_q in [16, 64, 128] %} +template cudaError_t BatchPODWithKVCacheTensorDispatched< + {{ head_dim_qk }}, {{ head_dim_vo }}, POS_ENCODING_MODE, + {{ use_fp16_qk_reduction }}, /*CTA_TILE_Q_P=*/{{cta_tile_q}}, {{ mask_mode_p }}, + /*CTA_TILE_Q_D=*/16, {{ mask_mode_d }}, {{ variant_name_p }}, + {{ variant_name_d }}, PrefillParams, DecodeParams>( + PrefillParams prefill_params, {{ dtype_o }}* tmp_v_p, float *tmp_s_p, + DecodeParams decode_params, {{ dtype_o }}* tmp_v_d, float *tmp_s_d, + bool enable_pdl, cudaStream_t stream, int* sm_aware_sched); +{% endfor %} +} diff --git a/csrc/batch_prefill.cu b/csrc/batch_prefill.cu index 5d7182bdc5..6011ba2063 100644 --- a/csrc/batch_prefill.cu +++ b/csrc/batch_prefill.cu @@ -50,7 +50,7 @@ Array BatchPrefillWithKVCachePlan( TensorView kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal, int64_t window_left, int64_t fixed_split_size, - bool disable_split_kv) { + bool disable_split_kv, int64_t num_colocated_ctas = 0) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = @@ -58,7 +58,7 @@ Array BatchPrefillWithKVCachePlan( PrefillPlanInfo plan_info; - cudaSetDevice(float_workspace_buffer.device().device_id); + ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id); const cudaStream_t stream = get_stream(float_workspace_buffer.device()); cudaError_t status = PrefillPlan( float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, @@ -66,7 +66,8 @@ Array BatchPrefillWithKVCachePlan( int_workspace_size_in_bytes, plan_info, static_cast(qo_indptr.data_ptr()), static_cast(kv_indptr.data_ptr()), total_num_rows, batch_size, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, page_size, enable_cuda_graph, - /*sizeof_dtype_o=*/2, window_left, fixed_split_size, disable_split_kv, stream); + /*sizeof_dtype_o=*/2, window_left, fixed_split_size, disable_split_kv, num_colocated_ctas, + stream); TVM_FFI_ICHECK(status == cudaSuccess) << "Failed to plan prefill with error: " << cudaGetErrorString(status); @@ -113,7 +114,7 @@ void BatchPrefillWithRaggedKVCacheRun(TensorView float_workspace_buffer, const MaskMode mask_mode = static_cast(mask_mode_code); - cudaSetDevice(float_workspace_buffer.device().device_id); + ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id); const cudaStream_t stream = get_stream(float_workspace_buffer.device()); DISPATCH_context( @@ -246,7 +247,7 @@ void BatchPrefillWithPagedKVCacheRun(TensorView float_workspace_buffer, << "k/v strides differs at " << i; } - cudaSetDevice(float_workspace_buffer.device().device_id); + ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id); const cudaStream_t stream = get_stream(float_workspace_buffer.device()); DISPATCH_context( diff --git a/csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja b/csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja index 8225edbb00..0b615e57e8 100644 --- 
a/csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja +++ b/csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja @@ -1 +1,15 @@ -// TODO: Not implemented yet +#include +#include "batch_prefill_sm90_config.inc" + +namespace flashinfer { + +{% for same_scheduler_for_all_heads in ["true", "false"] %} +template cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched + <{{ head_dim_qk }}, + {{ mask_mode }}, + /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }}, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }}, + {{ variant_name }}, RaggedParams>(RaggedParams& params, bool enable_pdl, cudaStream_t stream); +{% endfor %} + +} // namespace flashinfer diff --git a/csrc/batch_prefill_fp8_sm90.cu b/csrc/batch_prefill_fp8_sm90.cu index 7c8680dc0b..6bf67c9928 100644 --- a/csrc/batch_prefill_fp8_sm90.cu +++ b/csrc/batch_prefill_fp8_sm90.cu @@ -29,6 +29,11 @@ template +cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched(Params& params, bool enable_pdl, + cudaStream_t stream); + } // namespace flashinfer using namespace flashinfer; @@ -50,7 +55,7 @@ Array BatchPrefillWithKVCacheSM90Plan( flashinfer::PrefillPlanSM90Info plan_info; - cudaSetDevice(float_workspace_buffer.device().device_id); + ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id); const cudaStream_t stream = get_stream(float_workspace_buffer.device()); cudaError_t status = PrefillSM90Plan( @@ -78,7 +83,94 @@ void BatchPrefillWithRaggedKVCacheSM90Run(ffi::TensorView float_workspace_buffer int64_t window_left, bool enable_pdl // placeholder ADDITIONAL_FUNC_PARAMS) { - return; // TODO: Implement this function + PrefillPlanSM90Info plan_info; + plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); + + if (maybe_lse.has_value()) { + const auto& lse = maybe_lse.value(); + TVM_FFI_ICHECK_EQ(lse.size(0), q.size(0)); + TVM_FFI_ICHECK_EQ(lse.size(1), q.size(1)); + } + + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); + + int64_t head_dim_qk = q.size(2); + int64_t head_dim_vo = v.size(2); + + QKVLayout kv_layout = static_cast(layout); + + cudaSetDevice(float_workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer.device()); + const MaskMode mask_mode = static_cast(mask_mode_code); + bool use_swa = window_left != -1; + + DISPATCH_context( + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, + USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, [&] { + RaggedParams params; + + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(k.data_ptr()); + params.v_ptr = static_cast(v.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? 
static_cast(maybe_lse.value().data_ptr()) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); + if (kv_layout == QKVLayout::kNHD) { + params.k_stride_n = k.stride(0); + params.k_stride_h = k.stride(1); + params.v_stride_n = v.stride(0); + params.v_stride_h = v.stride(1); + } else { + params.k_stride_h = k.stride(0); + params.k_stride_n = k.stride(1); + params.v_stride_h = v.stride(0); + params.v_stride_n = v.stride(1); + } + params.nnz_qo = q.size(0); + params.nnz_kv = k.size(0); + params.num_qo_heads = q.size(1); + params.num_kv_heads = k.size(1); + params.group_size = params.num_qo_heads / params.num_kv_heads; + params.window_left = window_left; + params.causal = mask_mode_code == 1; + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.qo_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_indptr_offset); + params.kv_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_indptr_offset); + params.qo_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_len_offset); + params.kv_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_offset); + params.head_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset); + params.work_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); + params.batch_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.batch_indices_offset); + + ADDITIONAL_PARAMS_SETTER + + // Not support various head_dim for now + static_assert(HEAD_DIM_QK == HEAD_DIM_VO, "head_dim_qk and head_dim_vo should be the same"); + // Currently only support same quantization precision + static_assert(std::is_same_v); + + bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads; + DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] { + cudaError_t status = + BatchFP8PrefillWithRaggedKVCacheDispatched(params, enable_pdl, + stream); + + TVM_FFI_ICHECK(status == cudaSuccess) + << "BatchPrefillWithRaggedKVCacheSM90Run failed with error: " + << cudaGetErrorString(status); + return true; + }); + }); } void BatchPrefillWithPagedKVCacheSM90Run( @@ -111,7 +203,7 @@ void BatchPrefillWithPagedKVCacheSM90Run( void* float_buffer_ptr = float_workspace_buffer.data_ptr(); void* int_buffer_ptr = int_workspace_buffer.data_ptr(); - cudaSetDevice(float_workspace_buffer.device().device_id); + ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id); const cudaStream_t stream = get_stream(float_workspace_buffer.device()); const MaskMode mask_mode = static_cast(mask_mode_code); bool use_swa = window_left != -1; @@ -136,12 +228,18 @@ void BatchPrefillWithPagedKVCacheSM90Run( params.k_stride_h = paged_k_cache.stride(2); params.v_stride_n = paged_v_cache.stride(1); params.v_stride_h = paged_v_cache.stride(2); + // For sparse paged KV cache, store the stride between pages + params.k_page_stride = paged_k_cache.stride(0); + params.v_page_stride = paged_v_cache.stride(0); } else { // (num_pages, num_heads, page_size, head_dim) params.k_stride_h = paged_k_cache.stride(1); params.k_stride_n = paged_k_cache.stride(2); params.v_stride_h = paged_v_cache.stride(1); params.v_stride_n = paged_v_cache.stride(2); + // For sparse paged KV cache, store the stride between pages + params.k_page_stride = paged_k_cache.stride(0); + params.v_page_stride = paged_v_cache.stride(0); } params.nnz_qo = q.size(0); params.num_qo_heads = q.size(1); diff 
--git a/csrc/batch_prefill_jit_binding.cu b/csrc/batch_prefill_jit_binding.cu index da1e1981dc..3dda0f115a 100644 --- a/csrc/batch_prefill_jit_binding.cu +++ b/csrc/batch_prefill_jit_binding.cu @@ -25,7 +25,7 @@ Array BatchPrefillWithKVCachePlan( TensorView kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal, int64_t window_left, int64_t fixed_split_size, - bool disable_split_kv); + bool disable_split_kv, int64_t num_colocated_ctas); void BatchPrefillWithRaggedKVCacheRun(TensorView float_workspace_buffer, TensorView int_workspace_buffer, Array plan_info_vec, diff --git a/csrc/batch_prefill_sm90.cu b/csrc/batch_prefill_sm90.cu index 1cf78bab59..564ed6b08c 100644 --- a/csrc/batch_prefill_sm90.cu +++ b/csrc/batch_prefill_sm90.cu @@ -56,7 +56,7 @@ Array BatchPrefillWithKVCacheSM90Plan( flashinfer::PrefillPlanSM90Info plan_info; - cudaSetDevice(float_workspace_buffer.device().device_id); + ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id); const cudaStream_t stream = get_stream(float_workspace_buffer.device()); cudaError_t status = PrefillSM90Plan( @@ -97,7 +97,7 @@ void BatchPrefillWithRaggedKVCacheSM90Run( QKVLayout kv_layout = static_cast(layout); - cudaSetDevice(float_workspace_buffer.device().device_id); + ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id); const cudaStream_t stream = get_stream(float_workspace_buffer.device()); const MaskMode mask_mode = static_cast(mask_mode_code); bool use_swa = window_left != -1; @@ -193,7 +193,7 @@ void BatchPrefillWithPagedKVCacheSM90Run( void* float_buffer_ptr = float_workspace_buffer.data_ptr(); void* int_buffer_ptr = int_workspace_buffer.data_ptr(); - cudaSetDevice(float_workspace_buffer.device().device_id); + ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id); const cudaStream_t stream = get_stream(float_workspace_buffer.device()); const MaskMode mask_mode = static_cast(mask_mode_code); bool use_swa = window_left != -1; @@ -218,13 +218,24 @@ void BatchPrefillWithPagedKVCacheSM90Run( params.k_stride_h = paged_k_cache.stride(2); params.v_stride_n = paged_v_cache.stride(1); params.v_stride_h = paged_v_cache.stride(2); + // For sparse paged KV cache, store the stride between pages + params.k_page_stride = paged_k_cache.stride(0); + params.v_page_stride = paged_v_cache.stride(0); } else { // (num_pages, num_heads, page_size, head_dim) params.k_stride_h = paged_k_cache.stride(1); params.k_stride_n = paged_k_cache.stride(2); params.v_stride_h = paged_v_cache.stride(1); params.v_stride_n = paged_v_cache.stride(2); + // For sparse paged KV cache, store the stride between pages + params.k_page_stride = paged_k_cache.stride(0); + params.v_page_stride = paged_v_cache.stride(0); } + // Sparse mainloop assumes K and V have same strides for efficiency + TVM_FFI_ICHECK_EQ(params.k_page_stride, params.v_page_stride) + << "K and V must have same page stride for sparse attention"; + TVM_FFI_ICHECK_EQ(params.k_stride_n, params.v_stride_n) + << "K and V must have same stride_n for sparse attention"; params.nnz_qo = q.size(0); params.num_qo_heads = q.size(1); params.num_kv_heads = num_kv_heads; diff --git a/csrc/batch_prefill_sm90_customize_config.jinja b/csrc/batch_prefill_sm90_customize_config.jinja index b37ecac60d..640637c7df 100644 --- a/csrc/batch_prefill_sm90_customize_config.jinja +++ b/csrc/batch_prefill_sm90_customize_config.jinja @@ 
-104,6 +104,11 @@ struct PagedParams { int64_t o_stride_h; int64_t nnz_qo; + // NOTE: For sparse paged KV cache, we need the stride between pages + // This is paged_k_cache.stride(0), not the layout stride + int64_t k_page_stride; // Stride between pages for K + int64_t v_page_stride; // Stride between pages for V + int head_dim; int num_qo_heads; int num_kv_heads; diff --git a/csrc/blackwell_fmha_plan.cu b/csrc/blackwell_fmha_plan.cu index ef9b1475ea..e20b98179e 100644 --- a/csrc/blackwell_fmha_plan.cu +++ b/csrc/blackwell_fmha_plan.cu @@ -21,7 +21,7 @@ void blackwell_fmha_plan(TensorView qo_segment_offsets, TensorView kv_segment_of TensorView work_indptr, TensorView qo_tile_indices, TensorView head_indices, TensorView batch_indices, int64_t qo_tile_size, int64_t num_heads, int64_t num_buckets, bool causal) { - cudaSetDevice(qo_segment_offsets.device().device_id); + ffi::CUDADeviceGuard device_guard(qo_segment_offsets.device().device_id); const cudaStream_t stream = get_stream(qo_tile_indices.device()); int batch_size = qo_segment_offsets.size(0) - 1; diff --git a/csrc/bmm_fp8.cu b/csrc/bmm_fp8.cu index ea8417b617..4de464fac0 100644 --- a/csrc/bmm_fp8.cu +++ b/csrc/bmm_fp8.cu @@ -45,7 +45,7 @@ void bmm_fp8(TensorView A, TensorView B, TensorView D, TensorView A_scale, Tenso auto n = B.size(2); auto lt_handle = reinterpret_cast(cublas_handle); - cudaSetDevice(A.device().device_id); + ffi::CUDADeviceGuard device_guard(A.device().device_id); auto stream = get_stream(A.device()); auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt( diff --git a/csrc/cascade.cu b/csrc/cascade.cu index 98e4a590dc..4c3a64e95b 100644 --- a/csrc/cascade.cu +++ b/csrc/cascade.cu @@ -41,7 +41,7 @@ void merge_state(TensorView v_a, TensorView s_a, TensorView v_b, TensorView s_b, unsigned int num_heads = v_a.size(1); unsigned int head_dim = v_a.size(2); - cudaSetDevice(v_a.device().device_id); + ffi::CUDADeviceGuard device_guard(v_a.device().device_id); auto stream = get_stream(v_a.device()); bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(v_a.dtype(), c_type, [&] { @@ -85,7 +85,7 @@ void merge_state_in_place(TensorView v, TensorView s, TensorView v_other, Tensor unsigned int num_heads = v.size(1); unsigned int head_dim = v.size(2); - cudaSetDevice(v.device().device_id); + ffi::CUDADeviceGuard device_guard(v.device().device_id); auto stream = get_stream(v.device()); bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(v.dtype(), c_type, [&] { cudaError_t status = MergeStateInPlace( @@ -114,7 +114,7 @@ void merge_states(TensorView v, TensorView s, TensorView v_merged, TensorView s_ unsigned int num_heads = v.size(2); unsigned int head_dim = v.size(3); - cudaSetDevice(v.device().device_id); + ffi::CUDADeviceGuard device_guard(v.device().device_id); auto stream = get_stream(v.device()); bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(v.dtype(), c_type, [&] { cudaError_t status = MergeStates( diff --git a/csrc/cutlass_mla.cu b/csrc/cutlass_mla.cu index f68df30bea..4700fe5f96 100644 --- a/csrc/cutlass_mla.cu +++ b/csrc/cutlass_mla.cu @@ -23,7 +23,7 @@ using namespace flashinfer::attention; void CutlassMLAPagedAttention(ffi::TensorView workspace, ffi::TensorView out, ffi::TensorView lse, ffi::TensorView q_nope_pe, ffi::TensorView ckv_kpe_cache, ffi::TensorView kv_lens, ffi::TensorView page_table) { - cudaSetDevice(q_nope_pe.device().device_id); + ffi::CUDADeviceGuard device_guard(q_nope_pe.device().device_id); const cudaStream_t stream = get_stream(q_nope_pe.device()); int device_index = 
q_nope_pe.device().device_id; diff --git a/csrc/dsv3_router_gemm.cu b/csrc/dsv3_router_gemm.cu new file mode 100644 index 0000000000..2d44147d97 --- /dev/null +++ b/csrc/dsv3_router_gemm.cu @@ -0,0 +1,152 @@ +#include "flashinfer/gemm/dsv3_router_gemm.cuh" +#include "tvm_ffi_utils.h" + +namespace flashinfer::trtllm_dsv3_router_gemm { +template +void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream, + bool use_pdl = false) { + constexpr int VPT = 16 / sizeof(T); + constexpr int kBlockSize = 128; + cudaLaunchConfig_t config; + config.gridDim = kNumExperts; + config.blockDim = kBlockSize; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = use_pdl; + config.numAttrs = 1; + config.attrs = attrs; + auto status = cudaLaunchKernelEx( + &config, router_gemm_kernel, output, + mat_a, mat_b); + TVM_FFI_ICHECK(status == cudaSuccess) + << "cudaLaunchKernelEx failed with error code " << cudaGetErrorString(status); +} + +template void invokeRouterGemm<__nv_bfloat16, 1, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 2, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 3, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 4, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 5, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 6, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 7, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 8, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 9, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 10, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 11, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 12, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 13, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 14, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 15, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template +struct LoopUnroller { + static void unroll(int num_tokens, float* output, __nv_bfloat16 const* input, + __nv_bfloat16 const* weights, cudaStream_t stream, bool launch_with_pdl) { + 
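// [Editorial note, not part of this patch] LoopUnroller walks the inclusive range kBegin..kEnd at
// compile time: the branch below matches the runtime num_tokens against kBegin and, on a hit,
// instantiates invokeRouterGemm with that token count baked in as a template constant (matching
// the explicit instantiations for 1..16 above); otherwise it recurses with kBegin + 1. The
// kBegin == kEnd specialization that follows terminates the recursion and rejects unsupported
// token counts.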
if (num_tokens == kBegin) { + invokeRouterGemm<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, + stream, launch_with_pdl); + } else { + LoopUnroller::unroll( + num_tokens, output, input, weights, stream, launch_with_pdl); + } + } +}; + +template +struct LoopUnroller { + static void unroll(int num_tokens, float* output, __nv_bfloat16 const* input, + __nv_bfloat16 const* weights, cudaStream_t stream, bool launch_with_pdl) { + if (num_tokens == kEnd) { + invokeRouterGemm<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream, + launch_with_pdl); + } else { + throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16"); + } + } +}; + +void dsv3_router_gemm_op(TensorView mat_a, TensorView mat_b, TensorView out, bool launch_with_pdl) { + int const num_tokens = mat_a.sizes()[0]; + int const num_experts = mat_b.sizes()[1]; + int const hidden_dim = mat_a.sizes()[1]; + auto const out_dtype_ = out.dtype(); + auto const data_type = mat_a.dtype(); + constexpr int kNumExperts = 256; + constexpr int kHiddenDim = 7168; + std::vector output_size = {mat_a.sizes()[0], mat_b.sizes()[1]}; + TVM_FFI_ICHECK(mat_a.dim() == 2 && mat_b.dim() == 2) << "mat_a and mat_b must be 2D tensors"; + TVM_FFI_ICHECK(mat_a.strides()[1] == 1 && out.strides()[1] == 1) + << "mat_a and out must be row-major"; + TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major"; + auto stream = get_stream(mat_a.device()); + bool use_custom_kernel = false; + if (num_tokens >= 1 && num_tokens <= 16 && num_experts == kNumExperts && + hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code && + encode_dlpack_dtype(out_dtype_) == float32_code) { + use_custom_kernel = true; + } + + if (use_custom_kernel) { + LoopUnroller<1, 16, kNumExperts, kHiddenDim>::unroll( + num_tokens, reinterpret_cast(out.data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream, launch_with_pdl); + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input tensor size"; + } +} + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(dsv3_router_gemm_op, + flashinfer::trtllm_dsv3_router_gemm::dsv3_router_gemm_op); + +} // namespace flashinfer::trtllm_dsv3_router_gemm diff --git a/csrc/flashinfer_page_binding.cu b/csrc/flashinfer_page_binding.cu index dbab4f5cb8..97105712f7 100644 --- a/csrc/flashinfer_page_binding.cu +++ b/csrc/flashinfer_page_binding.cu @@ -27,12 +27,5 @@ void append_paged_mla_kv_cache(TensorView append_ckv, TensorView append_kpe, TensorView kpe_cache, TensorView kv_indices, TensorView kv_indptr, TensorView kv_last_page_len); -void block_sparse_indices_to_vector_sparse_offsets( - TensorView block_sparse_indices, TensorView block_sparse_indptr, - TensorView vector_sparse_offsets, TensorView vector_sparse_indptr, TensorView kv_len_arr, - int64_t stride_block, int64_t stride_n, int64_t batch_size, int64_t block_size); - TVM_FFI_DLL_EXPORT_TYPED_FUNC(append_paged_kv_cache, append_paged_kv_cache); TVM_FFI_DLL_EXPORT_TYPED_FUNC(append_paged_mla_kv_cache, append_paged_mla_kv_cache); -TVM_FFI_DLL_EXPORT_TYPED_FUNC(block_sparse_indices_to_vector_sparse_offsets, - block_sparse_indices_to_vector_sparse_offsets); diff --git a/csrc/flashinfer_rope_binding.cu b/csrc/flashinfer_rope_binding.cu index 23124064d8..94809da735 100644 --- a/csrc/flashinfer_rope_binding.cu +++ b/csrc/flashinfer_rope_binding.cu @@ -45,9 +45,19 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope 
TensorView pos_ids, double quant_scale_q, double quant_scale_kv, bool interleave, bool enable_pdl); +void rope_quantize_append_paged_kv_cache( + TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in, TensorView k_nope_in, + TensorView v_in, TensorView q_rope_out, TensorView q_nope_out, TensorView cos_sin_cache, + TensorView pos_ids, TensorView k_cache, TensorView v_cache, TensorView ckv_cache, + TensorView kpe_cache, TensorView kv_indices, TensorView kv_indptr, TensorView batch_indices, + TensorView positions, int64_t kv_layout_code, int64_t page_size, double quant_scale_q, + double quant_scale_kv, bool interleave, bool enable_pdl); + TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope, apply_rope); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_llama31_rope, apply_llama31_rope); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope_pos_ids, apply_rope_pos_ids); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_llama31_rope_pos_ids, apply_llama31_rope_pos_ids); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope_pos_ids_cos_sin_cache, apply_rope_pos_ids_cos_sin_cache); TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_quantize, rope_quantize); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_quantize_append_paged_kv_cache, + rope_quantize_append_paged_kv_cache); diff --git a/csrc/flashinfer_xqa_binding.cu b/csrc/flashinfer_xqa_binding.cu index 003a23a5f6..8bcbafafd6 100644 --- a/csrc/flashinfer_xqa_binding.cu +++ b/csrc/flashinfer_xqa_binding.cu @@ -17,37 +17,26 @@ #include "tvm_ffi_utils.h" #if MLA_WRAPPER -void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView output, TensorView q, -#if PAGED_KV_CACHE_LAYOUT == 1 - TensorView kCacheVLLM, TensorView vCacheVLLM, -#else - TensorView pool, -#endif - TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, - int64_t batchSize, TensorView kvCacheScale, TensorView semaphores, - TensorView scratch); +void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, + tvm::ffi::Optional qScaleTensor, TensorView output, TensorView q, + TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList, + int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, double kvCacheScale, + tvm::ffi::Optional kvScaleTensor, TensorView semaphores, + TensorView scratch, bool enable_pdl); TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper_mla, xqa_wrapper_mla); #else void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads, - int64_t slidingWinSize, double qScale, TensorView output, -#if LOW_PREC_OUTPUT - TensorView rcpOutScale, -#endif - TensorView q, tvm::ffi::Optional attentionSinks, -#if PAGED_KV_CACHE_LAYOUT == 1 - TensorView kCacheVLLM, TensorView vCacheVLLM, -#else - TensorView pool, -#endif - TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, - int64_t batchSize, TensorView kvCacheScale, -#if SPEC_DEC - int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask, -#endif - TensorView semaphores, TensorView scratch); + int64_t slidingWinSize, double qScale, tvm::ffi::Optional qScaleTensor, + TensorView output, double rcpOutScale, TensorView q, + tvm::ffi::Optional attentionSinks, TensorView kCacheVLLM, + TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen, + TensorView seqLen, int64_t batchSize, double kvCacheScale, + tvm::ffi::Optional kvScaleTensor, int64_t qSeqLen, + tvm::ffi::Optional mask, TensorView semaphores, TensorView scratch, + bool enable_pdl); TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper, xqa_wrapper); diff --git a/csrc/fmhaReduction.cu b/csrc/fmhaReduction.cu index 1f1ca8c755..e329e1c14b 100644 --- a/csrc/fmhaReduction.cu +++ 
b/csrc/fmhaReduction.cu @@ -34,7 +34,7 @@ namespace kernels { template __global__ void __launch_bounds__(NumThreadsPerCta, 2) - fmhaReductionKernel(KernelParams const params, int32_t numCtasForReduction, + fmhaReductionKernel(KernelParams const params, bool sparseMla, int32_t numCtasForReduction, int32_t numCtasForAllHeads, int32_t numHeadDimCtasV) { // clang-format off // The shape of partialO buffer: [batchSize, numHeadCtas, numCtasQ, numCtasKv, TileSizePerCtaQ, headDimPerCta]. @@ -64,10 +64,25 @@ __global__ void __launch_bounds__(NumThreadsPerCta, 2) // The number of validRows. int32_t const numValidRows{TileSizePerCtaQ}; + // The seqOffsetQ. + int32_t const seqOffsetQ{params.ptrCumSeqLensQ == nullptr ? batchIdx * params.mMaxSeqLenQ + : params.ptrCumSeqLensQ[batchIdx]}; + // The seqLenQ. + int32_t const seqLenQ{params.ptrCumSeqLensQ == nullptr + ? params.mMaxSeqLenQ + : (params.ptrCumSeqLensQ[batchIdx + 1] - seqOffsetQ)}; + // Early exit if ctaIdxQ >= seqLenQ, where each CTA processes one tokenQ. + if (ctaIdxQ >= seqLenQ) { + return; + } // The actual number of seqLenKv. int32_t seqLenKv{params.ptrSeqLensKv[batchIdx]}; // Consider the causal-mask speculative decoding. seqLenKv = seqLenKv - ((params.mMaxSeqLenQ - 1) - ctaIdxQ); + // Consider sparseMlaTopK. + if (sparseMla) { + seqLenKv = min(seqLenKv, params.mSparseMlaTopK); + } // The actual number of CtasKv (TileSizeKv is always 128 for now). int32_t numCtasKv{min((seqLenKv + 127) / 128, params.mMaxNumCtasKv)}; @@ -336,7 +351,7 @@ void runFmhaReduction(TllmGenFmhaKernelMetaInfo const& kernelMeta, KernelParams config.numAttrs = 1; // Select the kernel function pointer. - void (*kernel)(KernelParams const, int32_t, int32_t, int32_t) = nullptr; + void (*kernel)(KernelParams const, bool, int32_t, int32_t, int32_t) = nullptr; if (headDimPerCtaV == 128) { SELECT_FMHA_REDUCTION_KERNEL(128); } else if (headDimPerCtaV == 256) { @@ -346,8 +361,8 @@ void runFmhaReduction(TllmGenFmhaKernelMetaInfo const& kernelMeta, KernelParams } // Launch the kernel. - cudaLaunchKernelEx(&config, kernel, params, numCtasForReduction, numCtasForAllHeads, - numHeadDimCtasV); + cudaLaunchKernelEx(&config, kernel, params, kernelMeta.mSparseMla, numCtasForReduction, + numCtasForAllHeads, numHeadDimCtasV); cudaError_t err = cudaGetLastError(); FLASHINFER_CHECK(err == cudaSuccess, "Failed to launch kernel: ", cudaGetErrorString(err)); } diff --git a/csrc/fmha_cutlass_sm100.cu b/csrc/fmha_cutlass_sm100.cu index c50116fa7f..08f1235adf 100644 --- a/csrc/fmha_cutlass_sm100.cu +++ b/csrc/fmha_cutlass_sm100.cu @@ -96,7 +96,7 @@ void FMHACutlassSM100Run(ffi::TensorView workspace_buffer, ffi::TensorView q, ff int v_stride_n = v.stride(0); int v_stride_h = v.stride(1); - cudaSetDevice(qo_segment_offsets.device().device_id); + ffi::CUDADeviceGuard device_guard(qo_segment_offsets.device().device_id); const cudaStream_t stream = get_stream(o.device()); DISPATCH_context(DTypeIn, DTypeOut, HEAD_DIM_QK, HEAD_DIM_VO, MASK_MODE, [&] { diff --git a/csrc/fmha_v2/convert.cu b/csrc/fmha_v2/convert.cu new file mode 100644 index 0000000000..345bd008f9 --- /dev/null +++ b/csrc/fmha_v2/convert.cu @@ -0,0 +1,196 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. 
Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +__global__ void convert_int32_to_int8_kernel(void* dst, void const* src, size_t n, float scale) { + // The step. + size_t step = (size_t)gridDim.x * blockDim.x; + + // Iterate over the elements. + for (size_t ii = blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) { + // Load 4 integers. + int4 tmp = reinterpret_cast(src)[ii]; + + // Convert to float and scale. + float x = static_cast(tmp.x) * scale; + float y = static_cast(tmp.y) * scale; + float z = static_cast(tmp.z) * scale; + float w = static_cast(tmp.w) * scale; + + // Convert to int8. + uint32_t a; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(a) : "f"(x)); + uint32_t b; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(b) : "f"(y)); + uint32_t c; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(c) : "f"(z)); + uint32_t d; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(d) : "f"(w)); + + // Compact. + char4 out; + out.x = reinterpret_cast(a); + out.y = reinterpret_cast(b); + out.z = reinterpret_cast(c); + out.w = reinterpret_cast(d); + + // Store. + reinterpret_cast(dst)[ii] = reinterpret_cast(out); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_conversion_int32_to_int8(void* dst, void const* src, int s, int b, int h, int d, + float scale) { + size_t n = (size_t)s * b * h * d; + convert_int32_to_int8_kernel<<<512, 256>>>(dst, src, n, scale); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline typename fmha::Uint_from_size_in_bytes::Type pack_float4( + float4 const& f); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +__device__ inline uint2 pack_float4(float4 const& f) { + return fmha::float4_to_half4(f.x, f.y, f.z, f.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +__device__ inline uint2 pack_float4(float4 const& f) { + return fmha::float4_to_16bit_x4(f.x, f.y, f.z, f.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +__device__ inline uint32_t pack_float4(float4 const& f) { + return fmha::float4_to_e4m3x4(f.x, f.y, f.z, f.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +template <> +__device__ inline uint32_t pack_float4(float4 const& f) { + return fmha::float4_to_e5m2x4(f.x, f.y, f.z, f.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void convert_fp32_to_T_kernel(void* dst, void const* src, size_t n, float scale = 1.f) { + using Dst = typename fmha::Uint_from_size_in_bytes::Type; + + // The step. + size_t step = (size_t)gridDim.x * blockDim.x; + + // Iterate over the elements. + for (size_t ii = blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) { + // Load 4 floats. + float4 tmp = reinterpret_cast(src)[ii]; + // Scale. + tmp.x *= scale; + tmp.y *= scale; + tmp.z *= scale; + tmp.w *= scale; + // Convert to 4 Ts. 
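    // pack_float4<T> (specialized above) packs the four scaled floats into one
    // register-sized value of the destination type: fp16x4/bf16x4 -> uint2,
    // e4m3x4/e5m2x4 -> uint32_t, so the single store below writes all four
    // converted elements at once.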
+ auto out = pack_float4(tmp); + + // Store. + reinterpret_cast(dst)[ii] = reinterpret_cast(out); + } +} + +template +__global__ void convert_T_to_fp32_kernel(void* dst, void const* src, size_t n, float scale = 1.f) { + using Src = typename fmha::Uint_from_size_in_bytes::Type; + + union { + Src raw; + T elt[4]; + } data; + + // The step. + size_t step = (size_t)gridDim.x * blockDim.x; + + // Iterate over the elements. + for (size_t ii = blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) { + // Load 4 floats. + data.raw = reinterpret_cast(src)[ii]; + float4 out; + // Scale. + out.x = float(data.elt[0]) * scale; + out.y = float(data.elt[1]) * scale; + out.z = float(data.elt[2]) * scale; + out.w = float(data.elt[3]) * scale; + + // Store. + reinterpret_cast(dst)[ii] = reinterpret_cast(out); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_conversion_fp32_to_fp16(void* dst, void const* src, int s, int b, int h, int d) { + // No need to expose the scale factor for FP16/FP32. + size_t n = (size_t)s * b * h * d; + convert_fp32_to_T_kernel<<<512, 256>>>(dst, src, n, 1.f); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_conversion_fp32_to_bf16(void* dst, void const* src, int s, int b, int h, int d) { + // No need to expose the scale factor for FP16/FP32. + size_t n = (size_t)s * b * h * d; + convert_fp32_to_T_kernel<<<512, 256>>>(dst, src, n, 1.f); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_conversion_fp32_to_e4m3(void* dst, void const* src, size_t n, float scale_o) { + convert_fp32_to_T_kernel<<<512, 256>>>(dst, src, n, scale_o); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_conversion_e4m3_to_fp32(void* dst, void const* src, size_t n, float scale_o) { + convert_T_to_fp32_kernel<<<512, 256>>>(dst, src, n, scale_o); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_conversion_fp32_to_e4m3(void* dst, void const* src, int s, int b, int h, int d, + float scale_o) { + run_conversion_fp32_to_e4m3(dst, src, s * b * h * d, scale_o); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_conversion_fp32_to_e5m2(void* dst, void const* src, size_t n, float scale_o) { + convert_fp32_to_T_kernel<<<512, 256>>>(dst, src, n, scale_o); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_conversion_e5m2_to_fp32(void* dst, void const* src, size_t n, float scale_o) { + convert_T_to_fp32_kernel<<<512, 256>>>(dst, src, n, scale_o); +} diff --git a/csrc/fmha_v2/fmha/alibi_params.h b/csrc/fmha_v2/fmha/alibi_params.h new file mode 100644 index 0000000000..bee7ea1be9 --- /dev/null +++ b/csrc/fmha_v2/fmha/alibi_params.h @@ -0,0 +1,50 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. 
Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +namespace fmha { + +struct AlibiParams { + constexpr static int round_down_to_power_two(int x) { + x = x | (x >> 1); + x = x | (x >> 2); + x = x | (x >> 4); + x = x | (x >> 8); + x = x | (x >> 16); + return x - (x >> 1); + } + + AlibiParams() = default; + + AlibiParams(int h, float scale_after_alibi = 1.f) : scale_after_alibi(scale_after_alibi) { + h_pow_2 = round_down_to_power_two(h); + alibi_neg4_div_h = -4.0f / h_pow_2; + } + + AlibiParams(int h, int s, int tp_size, int rank, float scale_after_alibi = 1.f) + : AlibiParams(h * tp_size, scale_after_alibi) { + head_idx_offset = h * rank; + sequence_pos_offset = s * rank; + } + + int h_pow_2{}; + float alibi_neg4_div_h{}; + float scale_after_alibi{}; + // Could be simplified to `int rank` derive the others as `num_heads * rank, s * rank` at + // runtime, but this makes assumptions about the layout downstream + // (e.g. downstream may only split across the head dimension, so s would be the full sequence) + int head_idx_offset = 0; + int sequence_pos_offset = 0; +}; + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/fragment.h b/csrc/fmha_v2/fmha/fragment.h new file mode 100644 index 0000000000..01bdc0fdac --- /dev/null +++ b/csrc/fmha_v2/fmha/fragment.h @@ -0,0 +1,2311 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#pragma once + +#include +#include + +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_ldg {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_ldg<1> { + template + static inline __device__ void ldg(Fragment& f, int ii, void const* ptr) { + uint8_t tmp; + fmha::ldg(tmp, ptr); + f.u8(ii) = tmp; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_ldg<2> { + template + static inline __device__ void ldg(Fragment& f, int ii, void const* ptr) { + uint16_t tmp; + fmha::ldg(tmp, ptr); + f.u16(ii) = tmp; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_ldg<4> { + template + static inline __device__ void ldg(Fragment& f, int ii, void const* ptr) { + uint32_t tmp; + fmha::ldg(tmp, ptr); + f.reg(ii) = tmp; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_ldg<8> { + template + static inline __device__ void ldg(Fragment& f, int ii, void const* ptr) { + uint2 tmp; + fmha::ldg(tmp, ptr); + f.reg(2 * ii + 0) = tmp.x; + f.reg(2 * ii + 1) = tmp.y; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_ldg<16> { + template + static inline __device__ void ldg(Fragment& f, int ii, void const* ptr) { + uint4 tmp; + fmha::ldg(tmp, ptr); + f.reg(4 * ii + 0) = tmp.x; + f.reg(4 * ii + 1) = tmp.y; + f.reg(4 * ii + 2) = tmp.z; + f.reg(4 * ii + 3) = tmp.w; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_lds {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_lds<2> { + template + static inline __device__ void lds(Fragment& f, int ii, uint32_t ptr) { + uint16_t tmp; + fmha::lds(tmp, ptr); + f.u16(ii) = tmp; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_lds<4> { + template + static inline __device__ void lds(Fragment& f, int ii, uint32_t ptr) { + uint32_t tmp; + fmha::lds(tmp, ptr); + f.reg(ii) = tmp; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_lds<8> { + template + static inline __device__ void lds(Fragment& f, int ii, uint32_t ptr) { + uint2 tmp; + fmha::lds(tmp, ptr); + f.reg(2 * ii + 0) = tmp.x; + f.reg(2 * ii + 1) = tmp.y; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_lds<16> { + template + static inline __device__ void lds(Fragment& f, int ii, uint32_t ptr) { + uint4 tmp; + fmha::lds(tmp, ptr); + f.reg(4 * ii + 0) = tmp.x; + f.reg(4 * ii + 1) = tmp.y; + f.reg(4 * ii + 2) = tmp.z; + f.reg(4 * ii + 3) = tmp.w; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// template<> +// struct Fragment_lds<32> { +// template< typename Fragment > +// static inline __device__ void lds(Fragment &f, int ii, uint32_t ptr) { +// 
uint4 tmp; +// fmha::lds(tmp, ptr); +// f.reg(8*ii+0) = tmp.x; +// f.reg(8*ii+1) = tmp.y; +// f.reg(8*ii+2) = tmp.z; +// f.reg(8*ii+3) = tmp.w; +// +// fmha::lds(tmp, static_cast(ptr)+sizeof(uint4)); +// f.reg(8*ii+4) = tmp.x; +// f.reg(8*ii+5) = tmp.y; +// f.reg(8*ii+6) = tmp.z; +// f.reg(8*ii+7) = tmp.w; +// } +// }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_stg {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_stg<1> { + template + static inline __device__ void stg(void* ptr, Fragment const& f, int ii = 0) { + fmha::stg(ptr, f.u8(ii)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_stg<2> { + template + static inline __device__ void stg(void* ptr, Fragment const& f, int ii = 0) { + fmha::stg(ptr, f.u16(ii)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_stg<4> { + template + static inline __device__ void stg(void* ptr, Fragment const& f, int ii = 0) { + fmha::stg(ptr, f.reg(ii)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_stg<8> { + template + static inline __device__ void stg(void* ptr, Fragment const& f, int ii = 0) { + uint2 tmp; + tmp.x = f.reg(2 * ii + 0); + tmp.y = f.reg(2 * ii + 1); + fmha::stg(ptr, tmp); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_stg<16> { + template + static inline __device__ void stg(void* ptr, Fragment const& f, int ii = 0) { + uint4 tmp; + tmp.x = f.reg(4 * ii + 0); + tmp.y = f.reg(4 * ii + 1); + tmp.z = f.reg(4 * ii + 2); + tmp.w = f.reg(4 * ii + 3); + fmha::stg(ptr, tmp); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_base_ { + // The data type. + using Data_type = Data_type_; + // default input type + using Input_type_ = Data_type_; + + // Does it store the array of elements. + enum { HAS_ELTS = BITS_PER_ELT_ >= 8 }; + + // The number of elements. + enum { NUM_ELTS = NUM_ELTS_ }; + + // The size of element in bits. + enum { BITS_PER_ELT = BITS_PER_ELT_ }; + + // The size of byte of a single register. + enum { BYTES_PER_REG = 4 }; + + // The size in bits. + enum { BITS_PER_REG = BYTES_PER_REG * 8 }; + + // The number of registers needed to store the fragment. + enum { NUM_REGS = Div_up::VALUE }; + + // The size in bytes (as returned by sizeof(Fragment_base<>). + enum { SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG }; + + // The alignment. + enum { ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : Min::VALUE }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The type of the elements. + typename Data_type_, + // The number of elements. + int NUM_ELTS_, + // The size of each element in bits. + int BITS_PER_ELT_, + // The alignment if you want to force a value -- use 0 otherwise. + int ALIGNMENT_, + // The base class. + typename Base_ = Fragment_base_> +struct alignas(static_cast(Base_::ALIGNMENT)) Fragment_base : public Base_ { + // The size of a load/store. 
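  // (The whole fragment moves as one transaction: all NUM_REGS 32-bit
  // registers, i.e. NUM_REGS * 4 bytes per ldg/lds/stg.)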
+ enum { BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t) }; + + // Clear the fragment. Using PTX in that code seems to produce better SASS... + inline __device__ void clear() { +#pragma unroll + for (int ii = 0; ii < Base_::NUM_REGS; ++ii) { + asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) :); + } + } + + // Load from global memory. + inline __device__ void ldg(void const* ptr) { + Fragment_ldg::ldg(*this, 0, ptr); + } + + // Load from shared memory. + inline __device__ void lds(uint32_t ptr) { + Fragment_lds::lds(*this, 0, ptr); + } + + // Immutable access to a register. + inline __device__ uint32_t const& reg(int ii) const { return this->regs_[ii]; } + + // Mutable access to a register. + inline __device__ uint32_t& reg(int ii) { return this->regs_[ii]; } + + // Set the fragment with a scalar + inline __device__ void set(uint32_t value) { +#pragma unroll + for (int ii = 0; ii < Base_::NUM_REGS; ++ii) { + this->reg(ii) = value; + } + } + + // Store to global memory. + inline __device__ void stg(void* ptr) const { + Fragment_stg::stg(ptr, *this, 0); + } + + // Immutable access to a byte. + inline __device__ uint8_t u8(int ii) const { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Mutable access to a u8. + inline __device__ uint8_t& u8(int ii) { return reinterpret_cast(&this->regs_[0])[ii]; } + + // Immutable access to a half-word.. + inline __device__ uint16_t u16(int ii) const { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Mutable access to a half-word. + inline __device__ uint16_t& u16(int ii) { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Immutable access to a word. + inline __device__ uint32_t u32(int ii) const { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Mutable access to a word. + inline __device__ uint32_t& u32(int ii) { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Immutable access to a word. + inline __device__ uint2 u64(int ii) const { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Mutable access to a word. + inline __device__ uint2& u64(int ii) { return reinterpret_cast(&this->regs_[0])[ii]; } + + // The storage in registers. + // + // NOTE: Instead of using only an array of uint32_t, we could use a union so we could either + // access the registers or the elements. We found that for: + // + // union { + // uint16_t elts_[4]; uint32_t regs_[2]; + // }; + // + // The compiler does not always produce a final structure of 8B. So, for the moment we are + // going to go only with the regs_ array and use reinterpret_cast<> to access elements (see + // below). It may be worth revisiting that when time permits. + uint32_t regs_[Base_::NUM_REGS]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment : public Fragment_base { + // Immutable access to the elements. + inline __device__ Data_type_ const& elt(int ii) const { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Mutable access to the elements. + inline __device__ Data_type_& elt(int ii) { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Immutable access to the elements with a cast. + template + inline __device__ Cast_type const& elt_as(int ii) const { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Mutable access to the elements. + template + inline __device__ Cast_type& elt_as(int ii) { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Add another fragment. 
+ inline __device__ void add(Fragment const& other) { +#pragma unroll + for (int ii = 0; ii < NUM_ELTS_; ++ii) { + this->elt(ii) += other.elt(ii); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b 
: public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_accumulator {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // The traits. + using Traits = Volta_hmma_fp16_traits; + // The base class. + using Base = Fragment; + + // The fragments. + using Fragment_a = Fragment_a; + using Fragment_b = Fragment_b; + + // HMMA. + inline __device__ void mma(Fragment_a const& a, Fragment_b const& b) { + asm volatile( + "mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1, %2, %3}; \n" + : "+r"(this->reg(0)), "+r"(this->reg(1)), "+r"(this->reg(2)), "+r"(this->reg(3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(b.reg(0)), "r"(b.reg(1))); + asm volatile( + "mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1, %2, %3}; \n" + : "+r"(this->reg(0)), "+r"(this->reg(1)), "+r"(this->reg(2)), "+r"(this->reg(3)) + : "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(2)), "r"(b.reg(3))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // The fragments. + using Fragment_a = Fragment_a; + using Fragment_b = Fragment_b; + + // HMMA. + inline __device__ void mma(Fragment_a const& a, Fragment_b const& b) { + // K = 0..3 for threads 0..7 and 16..23 and K = 4..7 for 8..15 and 24..31. + asm volatile( + "mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1, %2, %3}; \n" + : "+r"(this->reg(0)), "+r"(this->reg(1)), "+r"(this->reg(2)), "+r"(this->reg(3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(b.reg(0)), "r"(b.reg(1))); + asm volatile( + "mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1, %2, %3}; \n" + : "+r"(this->reg(4)), "+r"(this->reg(5)), "+r"(this->reg(6)), "+r"(this->reg(7)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(b.reg(2)), "r"(b.reg(3))); + + // K = 8..11 for threads 0..7 and 16..23 and K = 12..15 for 8..15 and 24..31. 
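    // The two MMAs below reuse the second half of A (a.reg(2), a.reg(3)); the
    // first accumulates into reg(0..3) with b.reg(4..5), the second into
    // reg(4..7) with b.reg(6..7).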
+ asm volatile( + "mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1, %2, %3}; \n" + : "+r"(this->reg(0)), "+r"(this->reg(1)), "+r"(this->reg(2)), "+r"(this->reg(3)) + : "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(4)), "r"(b.reg(5))); + asm volatile( + "mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1, %2, %3}; \n" + : "+r"(this->reg(4)), "+r"(this->reg(5)), "+r"(this->reg(6)), "+r"(this->reg(7)) + : "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(6)), "r"(b.reg(7))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // The fragments. + using Fragment_a = Fragment_a; + using Fragment_b = Fragment_b; + + // IMMA. + inline __device__ void mma(Fragment_a const& a, Fragment_b const& b) { +#pragma unroll + for (int i = 0; i < 4; ++i) { + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 \n" + " {%0, %1}, \n" + " {%2}, \n" + " {%3}, \n" + " {%0, %1}; \n" + : "+r"(this->reg(2 * i + 0)), "+r"(this->reg(2 * i + 1)) + : "r"(a.reg(i / 2)), "r"(b.reg(i % 2))); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // Do the HMMA. + template + inline __device__ void mma(Fragment_a const& a, + Fragment_b const& b) { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 \n" + " {%0, %1}, \n" + " {%2, %3}, \n" + " {%4}, \n" + " {%0, %1}; \n" + : "+r"(reg(0)), "+r"(reg(1)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(b.reg(0))); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 \n" + " {%0, %1}, \n" + " {%2, %3}, \n" + " {%4}, \n" + " {%0, %1}; \n" + : "+r"(reg(2)), "+r"(reg(3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(b.reg(1))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // Add two fragments. + template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) = this->elt(ii) + other.elt(ii); + } + } + + inline __device__ void mul(float const other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) *= other; + } + } + + // Do the HMMA. + template + inline __device__ void mma(Fragment_a const& a, + Fragment_b const& b) { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(elt(0)), "+f"(elt(1)), "+f"(elt(2)), "+f"(elt(3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(b.reg(0))); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(elt(4)), "+f"(elt(5)), "+f"(elt(6)), "+f"(elt(7)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(b.reg(1))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // The fragments. + using Fragment_a = Fragment_a; + using Fragment_b = Fragment_b; + + // IMMA. 
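    // Four m8n8k16 s8 MMAs cover the tile: i/2 selects the half of A, i%2 the
    // half of B, and each MMA updates one pair of s32 accumulator registers
    // (reg(2*i), reg(2*i+1)).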
+ inline __device__ void mma(Fragment_a const& a, Fragment_b const& b) { +#pragma unroll + for (int i = 0; i < 4; ++i) { + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 \n" + " {%0, %1}, \n" + " {%2}, \n" + " {%3}, \n" + " {%0, %1}; \n" + : "+r"(this->reg(2 * i + 0)), "+r"(this->reg(2 * i + 1)) + : "r"(a.reg(i / 2)), "r"(b.reg(i % 2))); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // Do the HMMA. + template + inline __device__ void mma(Fragment_a const& a, + Fragment_b const& b) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 \n" + " {%0, %1}, \n" + " {%2, %3, %4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1}; \n" + : "+r"(reg(0)), "+r"(reg(1)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(0)), "r"(b.reg(1))); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 \n" + " {%0, %1}, \n" + " {%2, %3, %4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1}; \n" + : "+r"(reg(2)), "+r"(reg(3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(2)), "r"(b.reg(3))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// BF16 MMA must accumulate with at least FP32 +template <> +struct Fragment_accumulator : public Fragment { + // Do the HMMA. + template + inline __device__ void mma(Fragment_a const& a, + Fragment_b const& b) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n" + " {%0, %1}, \n" + " {%2, %3, %4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1}; \n" + : "+r"(reg(0)), "+r"(reg(1)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(0)), "r"(b.reg(1))); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n" + " {%0, %1}, \n" + " {%2, %3, %4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1}; \n" + : "+r"(reg(2)), "+r"(reg(3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(2)), "r"(b.reg(3))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // Add two fragments. + template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) = this->elt(ii) + other.elt(ii); + } + } + + inline __device__ void mul(float const other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) *= other; + } + } + + // Do the HMMA. 
+ template + inline __device__ void mma(Fragment_a const& a, + Fragment_b const& b) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(elt(0)), "+f"(elt(1)), "+f"(elt(2)), "+f"(elt(3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(0)), "r"(b.reg(1))); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(elt(4)), "+f"(elt(5)), "+f"(elt(6)), "+f"(elt(7)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(2)), "r"(b.reg(3))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// BF16 MMA must accumulate with at least FP32 +template <> +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // Add two fragments. + template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) = this->elt(ii) + other.elt(ii); + } + } + + inline __device__ void mul(float const other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) *= other; + } + } + + // Do the HMMA. + template + inline __device__ void mma(Fragment_a const& a, + Fragment_b const& b) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(elt(0)), "+f"(elt(1)), "+f"(elt(2)), "+f"(elt(3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(0)), "r"(b.reg(1))); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(elt(4)), "+f"(elt(5)), "+f"(elt(6)), "+f"(elt(7)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(2)), "r"(b.reg(3))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // The fragments. + using Fragment_a = Fragment_a; + using Fragment_b = Fragment_b; + + // IMMA. + inline __device__ void mma(Fragment_a const& a, Fragment_b const& b) { +#pragma unroll + for (int i = 0; i < 2; ++i) { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+r"(reg(i * 4 + 0)), "+r"(reg(i * 4 + 1)), "+r"(reg(i * 4 + 2)), "+r"(reg(i * 4 + 3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(i * 2)), + "r"(b.reg(i * 2 + 1))); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // The fragments. + using Fragment_a = Fragment_a; + using Fragment_b = Fragment_b; + + // IMMA. 
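    // The e4m3 MMA below is only available on SM89 (Ada) and newer; the
    // __CUDA_ARCH__ < 890 branch compiles to a trap so a mismatched build
    // fails loudly at runtime instead of producing garbage.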
+ inline __device__ void mma(Fragment_a const& a, Fragment_b const& b) { +#pragma unroll + for (int i = 0; i < 2; ++i) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890 + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+r"(reg(i * 4 + 0)), "+r"(reg(i * 4 + 1)), "+r"(reg(i * 4 + 2)), "+r"(reg(i * 4 + 3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(i * 2)), + "r"(b.reg(i * 2 + 1))); +#else + asm volatile("trap;\n"); +#endif + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // The fragments. + using Fragment_a = Fragment_a; + using Fragment_b = Fragment_b; + + // IMMA. + inline __device__ void mma(Fragment_a const& a, Fragment_b const& b) { +#pragma unroll + for (int i = 0; i < 2; ++i) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890 + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 \n" + " {%0, %1}, \n" + " {%2, %3, %4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1}; \n" + : "+r"(reg(i * 2 + 0)), "+r"(reg(i * 2 + 1)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(i * 2)), + "r"(b.reg(i * 2 + 1))); +#else + asm volatile("trap;\n"); +#endif + } + } +}; + +template +struct Tile_o_normalizer { + // The fragment accumulator. + using Fragment_accu = Fragment_accumulator; + + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::VALID_MMAS_N }; + + // The number of rows per thread. + enum { ROWS_PER_THREAD = 2 * MMAS_M }; + + // The number of registers per thread + enum { REGS_PER_THREAD = 4 }; + + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // softmax data bytes + enum { BYTES_PER_ELEMENT = sizeof(float) }; + + // Initialize the attention sinks. + template + inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo) + : attention_sink_value_(params.attention_sinks != nullptr ? params.attention_sinks[binfo.bidh] + : -FLT_MAX) {} + + // Update the sum when attention sinks are used. + inline __device__ void update_sum(float const (&max)[ROWS_PER_THREAD], + float (&sum)[ROWS_PER_THREAD]) { +#pragma unroll + for (int i = 0; i < ROWS_PER_THREAD; ++i) { + sum[i] += expf(attention_sink_value_ - max[i]); + } + } + + // Update o. + inline __device__ void update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], + float (&curr_max)[ROWS_PER_THREAD], + float const (&prev_max)[ROWS_PER_THREAD], + float (&sum)[ROWS_PER_THREAD]) { +#ifdef HALF_ACCUMULATION_FOR_FLASH_ATTENTION // Half accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + uint32_t alpha[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + // The multiplier. + curr_max[jj] = fmax(prev_max[jj], curr_max[jj]); + float a = expf(prev_max[jj] - curr_max[jj]); + sum[jj] *= a; + // Convert back to FP16x2. + alpha[ii] = fmha::float2_to_half2(a, a); + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The accumulators in FP16x2. 
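          // alpha = exp(prev_max - new_max) rescales the previously
          // accumulated output rows to the new running maximum; it is packed
          // as FP16x2 so a single HMUL2 below scales both lanes of each
          // accumulator register.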
+ uint32_t acc_o_pair = acc_o[mi][ni].reg(ii); + + // Apply the scaling. + acc_o_pair = fmha::hmul2(alpha[ii & 1], acc_o_pair); + + // Update the register. + acc_o[mi][ni].reg(ii) = acc_o_pair; + } + } + } +#else // Float accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + float alpha[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + // The multiplier. + curr_max[jj] = fmax(prev_max[jj], curr_max[jj]); + alpha[ii] = expf(prev_max[jj] - curr_max[jj]); + sum[jj] *= alpha[ii]; + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The registers. + float2 acc_o_pair = fmha::half2_to_float2(acc_o[mi][ni].reg(ii)); + + // Do the math in Fp32. + acc_o_pair.x = alpha[ii & 1] * acc_o_pair.x; + acc_o_pair.y = alpha[ii & 1] * acc_o_pair.y; + + // Convert back to Fp16x2. + acc_o[mi][ni].reg(ii) = fmha::float2_to_half2(acc_o_pair); + } + } + } +#endif // defined HALF_ACCUMULATION_FOR_FLASH_ATTENTION + } + + // Update o. + inline __device__ void final_update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], + float (&sum)[ROWS_PER_THREAD]) { +#ifdef HALF_ACCUMULATION_FOR_FLASH_ATTENTION // Half accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + uint32_t beta[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + float b = (sum[jj] == 0.f || sum[jj] != sum[jj]) ? 1.f : 1.f / sum[jj]; + // Convert back to FP16x2. + beta[ii] = fmha::float2_to_half2(b, b); + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The accumulators in FP16x2. + uint32_t acc_o_pair = acc_o[mi][ni].reg(ii); + + // Apply the scaling. + acc_o_pair = fmha::hmul2(acc_o_pair, beta[ii & 1]); + + // Update the register. + acc_o[mi][ni].reg(ii) = acc_o_pair; + } + } + } +#else // Float accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + float beta[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + // The diviser. + beta[ii] = (sum[jj] == 0.f || sum[jj] != sum[jj]) ? 1.f : 1.f / sum[jj]; + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The registers. + float2 acc_o_pair = fmha::half2_to_float2(acc_o[mi][ni].reg(ii)); + + // Do the math in Fp32. + acc_o_pair.x = acc_o_pair.x * beta[ii & 1]; + acc_o_pair.y = acc_o_pair.y * beta[ii & 1]; + + // Convert back to Fp16x2. + acc_o[mi][ni].reg(ii) = fmha::float2_to_half2(acc_o_pair); + } + } + } +#endif // defined HALF_ACCUMULATION_FOR_FLASH_ATTENTION + } + + // Attention sink value. + float attention_sink_value_; +}; + +template +struct Tile_o_normalizer_fp32 { + // The fragment accumulator. + using Fragment_accu = Fragment_accumulator; + + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in the M dimension. + enum { MMAS_M = Mma_tile::MMAS_M }; + + // The number of MMAs in the N dimension. + enum { MMAS_N = Mma_tile::VALID_MMAS_N }; + + // The number of rows per thread. + enum { ROWS_PER_THREAD = 2 * MMAS_M }; + + // The number of registers per thread. + enum { REGS_PER_THREAD = 8 }; + + // Warps. 
+ enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // softmax data bytes + enum { BYTES_PER_ELEMENT = sizeof(float) }; + + // Initialize the attention sinks. + template + inline __device__ Tile_o_normalizer_fp32(Params const& params, Block_info const& binfo) + : attention_sink_value_(params.attention_sinks != nullptr ? params.attention_sinks[binfo.bidh] + : -FLT_MAX) {} + + // Update the sum when attention sinks are used. + inline __device__ void update_sum(float const (&max)[ROWS_PER_THREAD], + float (&sum)[ROWS_PER_THREAD]) { +#pragma unroll + for (int i = 0; i < ROWS_PER_THREAD; ++i) { + sum[i] += expf(attention_sink_value_ - max[i]); + } + } + + // Update o. + inline __device__ void update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], + float (&curr_max)[ROWS_PER_THREAD], + float const (&prev_max)[ROWS_PER_THREAD], + float (&sum)[ROWS_PER_THREAD]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + float alpha[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + // The multiplier. + curr_max[jj] = fmax(prev_max[jj], curr_max[jj]); + alpha[ii] = expf(prev_max[jj] - curr_max[jj]); + sum[jj] *= alpha[ii]; + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The register for O. + float acc_o_f = acc_o[mi][ni].elt(ii); + // Compute the next accumulator. + acc_o_f = alpha[(ii & 2) / 2] * acc_o_f; + // Update the accumulator. + acc_o[mi][ni].elt(ii) = acc_o_f; + } + } + } + } + + // Update o after P * V + inline __device__ void final_update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], + float (&sum)[ROWS_PER_THREAD]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + float beta[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + + // The diviser. + beta[ii] = (sum[jj] == 0.f || sum[jj] != sum[jj]) ? 1.f : 1.f / sum[jj]; + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The register for O. + float acc_o_f = acc_o[mi][ni].elt(ii); + // Compute the next accumulator. + acc_o_f = acc_o_f * beta[(ii & 2) / 2]; + // Update the accumulator. + acc_o[mi][ni].elt(ii) = acc_o_f; + } + } + } + } + + // Attention sink value. + float attention_sink_value_; +}; + +template +struct Tile_o_normalizer + : public Tile_o_normalizer_fp32 { + // The traits. + using Traits = fmha::Ampere_hmma_fp32_traits; + // The base class. + using Base = Tile_o_normalizer_fp32; + + // The ctor. + template + inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo) + : Base(params, binfo) {} +}; + +template +struct Tile_o_normalizer + : public Tile_o_normalizer_fp32 { + // The traits. + using Traits = fmha::Ampere_hmma_bf16_traits; + // The base class. + using Base = Tile_o_normalizer_fp32; + + // The ctor. + template + inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo) + : Base(params, binfo) {} +}; + +// The attention sinks are not enabled for Volta. +template +struct Tile_o_normalizer { + // The traits. + using Traits = Volta_hmma_fp16_16x16x16_traits; + + // The fragments. + using Fragment_accu = Fragment_accumulator; + + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in M/N dimensions. 
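// -----------------------------------------------------------------------------
// Illustrative aside, not part of this diff: a scalar, host-side sketch of the
// online-softmax bookkeeping that the Tile_o_normalizer variants above apply to
// packed MMA accumulators. The names (OnlineSoftmaxRow, merge_tile, finalize)
// are invented for illustration only.
// -----------------------------------------------------------------------------
#include <cmath>
#include <vector>

struct OnlineSoftmaxRow {
  float row_max = -HUGE_VALF;  // running maximum of the scores seen so far
  float row_sum = 0.f;         // running sum of exp(score - row_max)
  std::vector<float> acc;      // running, un-normalized output row

  explicit OnlineSoftmaxRow(size_t head_dim) : acc(head_dim, 0.f) {}

  // Fold in one K/V tile: `scores` are this row's raw (pre-softmax) logits
  // against the tile, `v[k]` is the k-th value row of the tile.
  void merge_tile(const std::vector<float>& scores,
                  const std::vector<std::vector<float>>& v) {
    float tile_max = -HUGE_VALF;
    for (float s : scores) tile_max = std::fmax(tile_max, s);
    float new_max = std::fmax(row_max, tile_max);

    // Rescale the previous partial state to the new maximum (the `alpha`
    // factor in Tile_o_normalizer::update).
    float alpha = std::exp(row_max - new_max);
    row_sum *= alpha;
    for (float& o : acc) o *= alpha;

    // Accumulate the new tile relative to the new maximum.
    for (size_t k = 0; k < scores.size(); ++k) {
      float p = std::exp(scores[k] - new_max);
      row_sum += p;
      for (size_t d = 0; d < acc.size(); ++d) acc[d] += p * v[k][d];
    }
    row_max = new_max;
  }

  // Final normalization (the `beta` factor in final_update), guarding rows
  // whose sum is zero or NaN (e.g. fully masked rows).
  void finalize() {
    float beta = (row_sum == 0.f || row_sum != row_sum) ? 1.f : 1.f / row_sum;
    for (float& o : acc) o *= beta;
  }
};
// -----------------------------------------------------------------------------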
+ enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::VALID_MMAS_N }; + + // The number of rows per thread. + enum { ROWS_PER_THREAD = MMAS_M }; + + // The number of registers per thread + enum { REGS_PER_THREAD = 8 }; + + // Update o. + inline __device__ void update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], + float (&curr_max)[ROWS_PER_THREAD], + float const (&prev_max)[ROWS_PER_THREAD], + float (&sum)[ROWS_PER_THREAD]) { +#ifdef HALF_ACCUMULATION_FOR_FLASH_ATTENTION // Half accumulation + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors. + uint32_t alpha; + // Update the curr_max. + curr_max[mi] = fmax(prev_max[mi], curr_max[mi]); + // The multiplier. + float a = expf(prev_max[mi] - curr_max[mi]); + // The accumulated sum. + sum[mi] *= a; + // Convert back to FP16. + alpha = fmha::float2_to_half2(a, a); + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The accumulators packed in FP16x2. + uint32_t acc_o_pair = acc_o[mi][ni].reg(ii); + + // Apply the scaling. + acc_o_pair = fmha::hmul2(acc_o_pair, alpha); + + // Update the register. + acc_o[mi][ni].reg(ii) = acc_o_pair; + } + } + } +#else // Float accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Update the curr_max. + curr_max[mi] = fmax(prev_max[mi], curr_max[mi]); + // The multiplier. + float alpha = expf(prev_max[mi] - curr_max[mi]); + // The accumulated sum. + sum[mi] *= alpha; + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The accumulators. Convert from FP16x2 to FP32x2. + float2 acc_o_pair = fmha::half2_to_float2(acc_o[mi][ni].reg(ii)); + + // Apply the scaling. + acc_o_pair.x = alpha * acc_o_pair.x; + acc_o_pair.y = alpha * acc_o_pair.y; + + // Update the register after converting back to FP16x2. + acc_o[mi][ni].reg(ii) = fmha::float2_to_half2(acc_o_pair); + } + } + } +#endif // defined HALF_ACCUMULATION_FOR_FLASH_ATTENTION + } + + // Update o. + inline __device__ void final_update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], + float const (&sum)[ROWS_PER_THREAD]) { +#ifdef HALF_ACCUMULATION_FOR_FLASH_ATTENTION // Half accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors. + uint32_t beta; + // The divisor. + float b = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi]; + // Convert back to FP16. + beta = fmha::float2_to_half2(b, b); + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The accumulators packed in FP16x2. + uint32_t acc_o_pair = acc_o[mi][ni].reg(ii); + + // Apply the scaling. + acc_o_pair = fmha::hmul2(acc_o_pair, beta); + + // Update the register. + acc_o[mi][ni].reg(ii) = acc_o_pair; + } + } + } +#else // Float accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // The divisor. + float beta = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi]; + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The registers. + float2 acc_o_pair = fmha::half2_to_float2(acc_o[mi][ni].reg(ii)); + + // Do the math in Fp32. + acc_o_pair.x = acc_o_pair.x * beta; + acc_o_pair.y = acc_o_pair.y * beta; + + // Convert back to Fp16x2. 
+ acc_o[mi][ni].reg(ii) = fmha::float2_to_half2(acc_o_pair); + } + } + } +#endif // defined HALF_ACCUMULATION_FOR_FLASH_ATTENTION + } +}; + +template +struct Tile_o_normalizer + : public Tile_o_normalizer_fp32 { + // The traits. + using Traits = fmha::Ada_qmma_e4m3_fp32_traits; + // The base class. + using Base = Tile_o_normalizer_fp32; + + // The ctor. + template + inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo) + : Base(params, binfo) {} + + // Update the sum. + inline __device__ void update_sum(float const (&max)[Base::ROWS_PER_THREAD], + float (&sum)[Base::ROWS_PER_THREAD]) { +// Take the log2f(Traits::SOFTMAX_FP_QUANT_SCALE) into account as the same scale has been applied to +// sum. +#pragma unroll + for (int i = 0; i < Base::ROWS_PER_THREAD; ++i) { + sum[i] += expf(this->attention_sink_value_ - max[i]) * Traits::SOFTMAX_FP_QUANT_SCALE; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Tile_o_normalizer + : public Tile_o_normalizer_fp32 { + // The traits. + using Traits = fmha::Ada_qmma_e4m3_fp32_traits; + // The base class. + using Base = Tile_o_normalizer_fp32; + + using Fragment_accu = Fragment_accumulator; + + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in the M dimension. + enum { MMAS_M = Mma_tile::MMAS_M }; + + // The number of MMAs in the N dimension. + enum { MMAS_N = Mma_tile::VALID_MMAS_N }; + + // The number of registers per thread. + enum { REGS_PER_THREAD = 8 }; + + // The ctor. + template + inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo) + : Base(params, binfo) {} + + inline __device__ void merge(Fragment_accu (&acc_dst)[MMAS_M][MMAS_N], + Fragment_accu (&acc_src)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + acc_dst[mi][ni].elt(ii) += acc_src[mi][ni].elt(ii); + } + } + } + } + + template + inline __device__ void move_to_first_block(Params const& params, int bidb, int bidh) { + int scale_iter = bidb * params.h * params.sage.v.max_nblock + bidh * params.sage.v.max_nblock; + + params_scale_v_iter = reinterpret_cast(params.sage.v.scales + scale_iter); + params_scale_v_ = __ldg(params_scale_v_iter); + } + + inline __device__ void move_to_next_block() { + params_scale_v_iter += 1; + params_scale_v_ = __ldg(params_scale_v_iter); + } + + inline __device__ void apply_scale(Fragment_accu (&acc_o)[MMAS_M][MMAS_N]) { + float const scale = reinterpret_cast(params_scale_v_); + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + float acc_o_f = acc_o[mi][ni].elt(ii); + acc_o_f = scale * acc_o_f; + acc_o[mi][ni].elt(ii) = acc_o_f; + } + } + } + } + + float const* params_scale_v_iter; + float params_scale_v_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_saver { + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::VALID_MMAS_N }; + + // The number of rows per thread. + enum { ROWS_PER_THREAD = 2 * MMAS_M }; + + // The number of registers per thread + enum { REGS_PER_THREAD = 4 }; + + // Warps. 
+ enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // softmax data bytes + enum { BYTES_PER_ELEMENT = sizeof(float) }; + + // Ctor. + template + inline __device__ Softmax_saver(Params const& params, Block_info const& binfo) + : actual_q_len_(binfo.actual_q_seqlen), + softmax_sum_ptr_(reinterpret_cast(params.softmax_stats_ptr)), + softmax_stats_stride_in_bytes_(params.softmax_stats_stride_in_bytes) { + softmax_max_ptr_ = reinterpret_cast(params.softmax_stats_ptr); + + int warp = threadIdx.x / Cta_tile::THREADS_PER_WARP; + int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; + // MMA row0 index (8x4 thread layout) + + int m_per_mma = 32 / Mma_tile::THREADS_PER_MMA_N * 2; + row0_ = (warp % WARPS_M) * m_per_mma + (lane / 4); + // Decide whether to store the lse values + store_softmax_ = (lane % 4 == 0 && int(warp / WARPS_M) == 0); + + // assume fixed seq length for the batch + size_t const bh_offset = (binfo.sum_s * params.h + binfo.bidh) * sizeof(float) * 2; + softmax_max_ptr_ += bh_offset + row0_ * params.softmax_stats_stride_in_bytes; + softmax_sum_ptr_ += bh_offset + row0_ * params.softmax_stats_stride_in_bytes + sizeof(float); + }; + + inline __device__ void store(int q_loop, float* p_sum, float* p_max) { + if (store_softmax_) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + float sum0 = p_sum[mi * 2]; + float sum1 = p_sum[mi * 2 + 1]; + float max0 = p_max[mi * 2]; + float max1 = p_max[mi * 2 + 1]; + + int row_offset = q_loop * Cta_tile::M + mi * Mma_tile::M_PER_MMA_PER_CTA; + if (row0_ + row_offset < actual_q_len_) { + fmha::stg(softmax_max_ptr_ + row_offset * softmax_stats_stride_in_bytes_, max0); + fmha::stg(softmax_sum_ptr_ + row_offset * softmax_stats_stride_in_bytes_, sum0); + } + if (row0_ + row_offset + 8 < actual_q_len_) { + fmha::stg(softmax_max_ptr_ + (row_offset + 8) * softmax_stats_stride_in_bytes_, max1); + fmha::stg(softmax_sum_ptr_ + (row_offset + 8) * softmax_stats_stride_in_bytes_, sum1); + } + } + } + } + + // ptr (total_token_q, h, 2) float + char* softmax_sum_ptr_ = nullptr; + char* softmax_max_ptr_ = nullptr; + + // the first row's idx + int row0_; + // actual seq length + int const actual_q_len_ = 0; + int const softmax_stats_stride_in_bytes_ = 0; + + // store lse or not + bool store_softmax_ = false; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Flash Attention: default applied to Turing, Ampere fp16 traits + +template +struct Fragment_updater { + // The fragment accumulator. + using Fragment_accu = Fragment_accumulator; + + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::VALID_MMAS_N }; + + // The number of rows per thread. + enum { ROWS_PER_THREAD = 2 * MMAS_M }; + + // The number of registers per thread + enum { REGS_PER_THREAD = 4 }; + + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // softmax data bytes + enum { BYTES_PER_ELEMENT = sizeof(float) }; + + // Ctor. 
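+  // Maps the thread to its first output row from its warp/lane indices and points softmax_lse_ptr_ at the [b, h, s] float LSE buffer for that row.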
+ template + inline __device__ Fragment_updater(Params const& params, Block_info const& binfo) + : actual_seqlen_(binfo.actual_seqlen), + softmax_lse_ptr_(reinterpret_cast(params.lse_ptr)) // [b, h, s] + { + int warp = threadIdx.x / Cta_tile::THREADS_PER_WARP; + int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; + // MMA row0 index (8x4 thread layout) + row0_ = (warp % WARPS_M) * Mma_tile::M_PER_MMA + (lane / 4); + // Decide whether to store the lse values + store_lse_ = (lane % 4 == 0 && int(warp / WARPS_M) == 0); + + // assume fixed seq length for the batch + size_t const bh_offset = + (binfo.bidb * params.h + binfo.bidh) * binfo.actual_seqlen * BYTES_PER_ELEMENT; + softmax_lse_ptr_ += bh_offset + row0_ * BYTES_PER_ELEMENT; + }; + + // init all statistics + inline __device__ Fragment_updater() { +#pragma unroll + for (int row_i = 0; row_i < ROWS_PER_THREAD; ++row_i) { + curr_max_[row_i] = -HUGE_VALF; + prev_max_[row_i] = -HUGE_VALF; + prev_sum_[row_i] = 0.0f; + curr_sum_[row_i] = 0.0f; + } + } + + // Update o. + inline __device__ void update_o(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], + Fragment_accu const (&local_acc_o)[MMAS_M][MMAS_N]) { +#ifdef HALF_ACCUMULATION_FOR_FLASH_ATTENTION // Half accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + uint32_t alpha[2], beta[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + // The multiplier. + float a = prev_sum_[jj] * __expf(prev_max_[jj] - curr_max_[jj]); + // The diviser. + float b = + (curr_sum_[jj] == 0.f || curr_sum_[jj] != curr_sum_[jj]) ? 1.f : 1.f / curr_sum_[jj]; + // Convert back to FP16x2. + alpha[ii] = fmha::float2_to_half2(a, a); + beta[ii] = fmha::float2_to_half2(b, b); + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The accumulators in FP16x2. + uint32_t local_o_pair = local_acc_o[mi][ni].reg(ii); + uint32_t acc_o_pair = acc_o[mi][ni].reg(ii); + + // Apply the scaling. + acc_o_pair = fmha::hfma2(alpha[ii & 1], acc_o_pair, local_o_pair); + acc_o_pair = fmha::hmul2(acc_o_pair, beta[ii & 1]); + + // Update the register. + acc_o[mi][ni].reg(ii) = acc_o_pair; + } + } + } +#else // Float accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + float alpha[2], beta[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + // The multiplier. + alpha[ii] = prev_sum_[jj] * __expf(prev_max_[jj] - curr_max_[jj]); + // The diviser. + beta[ii] = + (curr_sum_[jj] == 0.f || curr_sum_[jj] != curr_sum_[jj]) ? 1.f : 1.f / curr_sum_[jj]; + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The registers. + float2 local_o_pair = fmha::half2_to_float2(local_acc_o[mi][ni].reg(ii)); + float2 acc_o_pair = fmha::half2_to_float2(acc_o[mi][ni].reg(ii)); + + // Do the math in Fp32. + acc_o_pair.x = (alpha[ii & 1] * acc_o_pair.x + local_o_pair.x) * beta[ii & 1]; + acc_o_pair.y = (alpha[ii & 1] * acc_o_pair.y + local_o_pair.y) * beta[ii & 1]; + + // Convert back to Fp16x2. 
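+          // At this point acc_o holds the running output already divided by the current softmax sum (online-softmax rescaling).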
+ acc_o[mi][ni].reg(ii) = fmha::float2_to_half2(acc_o_pair); + } + } + } +#endif // defined HALF_ACCUMULATION_FOR_FLASH_ATTENTION + } + + // Update max scale + inline __device__ void update_acc_max() { +#pragma unroll + for (int row_i = 0; row_i < ROWS_PER_THREAD; ++row_i) { + float pre_curr_max_ = curr_max_[row_i]; + curr_max_[row_i] = fmaxf(prev_max_[row_i], curr_max_[row_i]); + prev_max_[row_i] = pre_curr_max_; + } + } + + // Update max scale + inline __device__ void update_acc_sum() { +#pragma unroll + for (int row_i = 0; row_i < ROWS_PER_THREAD; ++row_i) { + float pre_curr_sum_ = curr_sum_[row_i]; + curr_sum_[row_i] = + __expf(prev_max_[row_i] - curr_max_[row_i]) * curr_sum_[row_i] + prev_sum_[row_i]; + prev_sum_[row_i] = pre_curr_sum_; + } + } + + inline __device__ void store(int q_loop) { + if (store_lse_) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + float row0_lse = curr_max_[mi * 2] + __logf(curr_sum_[mi * 2]); + float row1_lse = curr_max_[mi * 2 + 1] + __logf(curr_sum_[mi * 2 + 1]); + int row_offset = q_loop * Cta_tile::M + mi * Mma_tile::M_PER_MMA_PER_CTA; + if (row0_ + row_offset < actual_seqlen_) { + fmha::stg(softmax_lse_ptr_ + row_offset * BYTES_PER_ELEMENT, row0_lse); + } + if (row0_ + row_offset + 8 < actual_seqlen_) { + fmha::stg(softmax_lse_ptr_ + (row_offset + 8) * BYTES_PER_ELEMENT, row1_lse); + } + } + } + } + + // Update scales. + float curr_max_[ROWS_PER_THREAD] = {-HUGE_VALF}; + float curr_sum_[ROWS_PER_THREAD] = {0}; + float prev_max_[ROWS_PER_THREAD] = {-HUGE_VALF}; + ; + float prev_sum_[ROWS_PER_THREAD] = {0}; + + // ptr + char* softmax_lse_ptr_ = nullptr; + + // the first row's idx + int row0_ = 0; + // actual seq length + int const actual_seqlen_ = 0; + + // store lse or not + bool store_lse_ = false; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Flash attention to update the accumulators in the 2nd GEMM when we accumulate in FP32. +// Support both hmma_fp32 and ampere_hmma_bf16 +template +struct Fragment_updater_ampere_fp32 { + // The fragment accumulator. + using Fragment_accu = Fragment_accumulator; + + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in the M dimension. + enum { MMAS_M = Mma_tile::MMAS_M }; + + // The number of MMAs in the N dimension. + enum { MMAS_N = Mma_tile::VALID_MMAS_N }; + + // The number of rows per thread. + enum { ROWS_PER_THREAD = 2 * MMAS_M }; + + // The number of registers per thread. + enum { REGS_PER_THREAD = 8 }; + + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // softmax data bytes + enum { BYTES_PER_ELEMENT = sizeof(float) }; + + // Ctor. 
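+  // Same thread-to-row mapping as the FP16 updater above; statistics are kept for the two rows each thread owns per MMA in M (ROWS_PER_THREAD == 2 * MMAS_M).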
+ template + inline __device__ Fragment_updater_ampere_fp32(Params const& params, Block_info const& binfo) + : actual_seqlen_(binfo.actual_seqlen), + softmax_lse_ptr_(reinterpret_cast(params.lse_ptr)) // [b, h, s] + { + int warp = threadIdx.x / Cta_tile::THREADS_PER_WARP; + int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; + // MMA row0 index (8x4 thread layout) + row0_ = (warp % WARPS_M) * Mma_tile::M_PER_MMA + (lane / 4); + // Decide whether to store the lse values + store_lse_ = (lane % 4 == 0 && int(warp / WARPS_M) == 0); + + // assume fixed seq length for the batch + size_t const bh_offset = + (binfo.bidb * params.h + binfo.bidh) * binfo.actual_seqlen * BYTES_PER_ELEMENT; + softmax_lse_ptr_ += bh_offset + row0_ * BYTES_PER_ELEMENT; + }; + + // init all statistics + inline __device__ Fragment_updater_ampere_fp32() { +#pragma unroll + for (int row_i = 0; row_i < ROWS_PER_THREAD; ++row_i) { + curr_max_[row_i] = -HUGE_VALF; + prev_max_[row_i] = -HUGE_VALF; + prev_sum_[row_i] = 0.0f; + curr_sum_[row_i] = 0.0f; + } + } + + // Update o after P * V + inline __device__ void update_o(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], + Fragment_accu const (&local_acc_o)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + float alpha[2], beta[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + // The multiplier. + alpha[ii] = prev_sum_[jj] * __expf(prev_max_[jj] - curr_max_[jj]); + // The diviser. + beta[ii] = + (curr_sum_[jj] == 0.f || curr_sum_[jj] != curr_sum_[jj]) ? 1.f : 1.f / curr_sum_[jj]; + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The register from P. + float local_acc_o_f = local_acc_o[mi][ni].elt(ii); + // The register for O. + float acc_o_f = acc_o[mi][ni].elt(ii); + // Compute the next accumulator. + acc_o_f = (alpha[(ii & 2) / 2] * acc_o_f + local_acc_o_f) * beta[(ii & 2) / 2]; + // Update the accumulator. + acc_o[mi][ni].elt(ii) = acc_o_f; + } + } + } + } + + // Update o before P * V + inline __device__ void update_o(Fragment_accu (&acc_o)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + float alpha[2], beta[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + // The multiplier. + alpha[ii] = prev_sum_[jj] * __expf(prev_max_[jj] - curr_max_[jj]); + // The diviser. + beta[ii] = + (curr_sum_[jj] == 0.f || curr_sum_[jj] != curr_sum_[jj]) ? 1.f : 1.f / curr_sum_[jj]; + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The register for O. + float acc_o_f = acc_o[mi][ni].elt(ii); + // Compute the next accumulator. + acc_o_f = alpha[(ii & 2) / 2] * acc_o_f * beta[(ii & 2) / 2]; + // Update the accumulator. 
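+          // The (ii & 2) / 2 selector above picks which of the two rows of the MMA tile this register belongs to.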
+ acc_o[mi][ni].elt(ii) = acc_o_f; + } + } + } + } + + // Update max scale + inline __device__ void update_acc_max() { +#pragma unroll + for (int ii = 0; ii < ROWS_PER_THREAD; ++ii) { + float curr_max = curr_max_[ii]; + curr_max_[ii] = fmaxf(prev_max_[ii], curr_max); + prev_max_[ii] = curr_max; + } + } + + // Update max scale + inline __device__ void update_acc_sum() { +#pragma unroll + for (int ii = 0; ii < ROWS_PER_THREAD; ++ii) { + float curr_sum = curr_sum_[ii]; + curr_sum_[ii] = __expf(prev_max_[ii] - curr_max_[ii]) * curr_sum_[ii] + prev_sum_[ii]; + prev_sum_[ii] = curr_sum; + } + } + + inline __device__ void store(int q_loop) { + if (store_lse_) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + float row0_lse = curr_max_[mi * 2] + __logf(curr_sum_[mi * 2]); + float row1_lse = curr_max_[mi * 2 + 1] + __logf(curr_sum_[mi * 2 + 1]); + int row_offset = q_loop * Cta_tile::M + mi * Mma_tile::M_PER_MMA_PER_CTA; + if (row0_ + row_offset < actual_seqlen_) { + fmha::stg(softmax_lse_ptr_ + row_offset * BYTES_PER_ELEMENT, row0_lse); + } + if (row0_ + row_offset + 8 < actual_seqlen_) { + fmha::stg(softmax_lse_ptr_ + (row_offset + 8) * BYTES_PER_ELEMENT, row1_lse); + } + } + } + } + + // Update scales. + float curr_max_[ROWS_PER_THREAD] = {-HUGE_VALF}; + float curr_sum_[ROWS_PER_THREAD] = {0}; + float prev_max_[ROWS_PER_THREAD] = {-HUGE_VALF}; + float prev_sum_[ROWS_PER_THREAD] = {0}; + + // ptr + char* softmax_lse_ptr_ = nullptr; + + // the first row's idx + int row0_ = 0; + // actual seq length + int const actual_seqlen_ = 0; + + // store lse or not + bool store_lse_ = false; +}; + +template +struct Fragment_updater + : public Fragment_updater_ampere_fp32 { + // The traits. + using Traits = fmha::Ampere_hmma_fp32_traits; + // The base class. + using Base = Fragment_updater_ampere_fp32; + + // Ctor. + template + inline __device__ Fragment_updater(Params const& params, Block_info const& binfo) + : Base(params, binfo) {} + + // Default ctor + Fragment_updater() = default; +}; + +template +struct Fragment_updater + : public Fragment_updater_ampere_fp32 { + // The traits. + using Traits = fmha::Ampere_hmma_bf16_traits; + // The base class. + using Base = Fragment_updater_ampere_fp32; + + // Ctor. + template + inline __device__ Fragment_updater(Params const& params, Block_info const& binfo) + : Base(params, binfo) {} + + // Default ctor + Fragment_updater() = default; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_updater + : public Fragment_updater_ampere_fp32 { + // The traits. + using Traits = fmha::Turing_hmma_fp32_traits; + // The base class. + using Base = Fragment_updater_ampere_fp32; + + // Ctor. + template + inline __device__ Fragment_updater(Params const& params, Block_info const& binfo) + : Base(params, binfo) {} + + // Default ctor + Fragment_updater() = default; +}; + +template +struct Fragment_updater + : public Fragment_updater_ampere_fp32 { + // The traits. + using Traits = fmha::Ada_qmma_e4m3_fp32_traits; + // The base class. + using Base = Fragment_updater_ampere_fp32; + + // Ctor. + template + inline __device__ Fragment_updater(Params const& params, Block_info const& binfo) + : Base(params, binfo) {} + + // Default ctor + Fragment_updater() = default; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_updater { + // The traits. + using Traits = Volta_hmma_fp16_16x16x16_traits; + + // The fragments. 
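+  // Volta HMMA FP16 accumulators stay packed as FP16x2 and each thread owns a single row per MMA, hence ROWS_PER_THREAD == MMAS_M below.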
+ using Fragment_accu = Fragment_accumulator; + + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::VALID_MMAS_N }; + + // The number of rows per thread. + enum { ROWS_PER_THREAD = MMAS_M }; + + // The number of registers per thread + enum { REGS_PER_THREAD = 8 }; + + // init all statistics + inline __device__ Fragment_updater() { +#pragma unroll + for (int row_i = 0; row_i < ROWS_PER_THREAD; ++row_i) { + curr_max_[row_i] = -HUGE_VALF; + prev_max_[row_i] = -HUGE_VALF; + prev_sum_[row_i] = 0.0f; + curr_sum_[row_i] = 0.0f; + } + } + + // Update o. + inline __device__ void update_o(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], + Fragment_accu const (&local_acc_o)[MMAS_M][MMAS_N]) { +#ifdef HALF_ACCUMULATION_FOR_FLASH_ATTENTION // Half accumulation + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors. + uint32_t alpha, beta; + // The multiplier. + float a = prev_sum_[mi] * __expf(prev_max_[mi] - curr_max_[mi]); + // The diviser. + float b = + (curr_sum_[mi] == 0.f || curr_sum_[mi] != curr_sum_[mi]) ? 1.f : 1.f / curr_sum_[mi]; + // Convert back to FP16. + alpha = fmha::float2_to_half2(a, a); + beta = fmha::float2_to_half2(b, b); + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The accumulators packed in FP16x2. + uint32_t local_o_pair = local_acc_o[mi][ni].reg(ii); + uint32_t acc_o_pair = acc_o[mi][ni].reg(ii); + + // Apply the scaling. + acc_o_pair = fmha::hmul2(fmha::hfma2(alpha, acc_o_pair, local_o_pair), beta); + + // Update the register. + acc_o[mi][ni].reg(ii) = acc_o_pair; + } + } + } +#else // Float accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // The multiplier. + float alpha = prev_sum_[mi] * __expf(prev_max_[mi] - curr_max_[mi]); + // The diviser. + float beta = + (curr_sum_[mi] == 0.f || curr_sum_[mi] != curr_sum_[mi]) ? 1.f : 1.f / curr_sum_[mi]; + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The accumulators. Convert from FP16x2 to FP32x2. + float2 local_o_pair = fmha::half2_to_float2(local_acc_o[mi][ni].reg(ii)); + float2 acc_o_pair = fmha::half2_to_float2(acc_o[mi][ni].reg(ii)); + + // Apply the scaling. + acc_o_pair.x = (alpha * acc_o_pair.x + local_o_pair.x) * beta; + acc_o_pair.y = (alpha * acc_o_pair.y + local_o_pair.y) * beta; + + // Update the register after converting back to FP16x2. 
+ acc_o[mi][ni].reg(ii) = fmha::float2_to_half2(acc_o_pair); + } + } + } +#endif // defined HALF_ACCUMULATION_FOR_FLASH_ATTENTION + } + + // Update max scale + inline __device__ void update_acc_max() { +#pragma unroll + for (int row_i = 0; row_i < ROWS_PER_THREAD; ++row_i) { + float pre_curr_max_ = curr_max_[row_i]; + curr_max_[row_i] = fmaxf(prev_max_[row_i], curr_max_[row_i]); + prev_max_[row_i] = pre_curr_max_; + } + } + + // Update max scale + inline __device__ void update_acc_sum() { +#pragma unroll + for (int row_i = 0; row_i < ROWS_PER_THREAD; ++row_i) { + float pre_curr_sum_ = curr_sum_[row_i]; + curr_sum_[row_i] = + __expf(prev_max_[row_i] - curr_max_[row_i]) * curr_sum_[row_i] + prev_sum_[row_i]; + prev_sum_[row_i] = pre_curr_sum_; + } + } + + // updater scales + float curr_max_[ROWS_PER_THREAD] = {-HUGE_VALF}; + float curr_sum_[ROWS_PER_THREAD] = {0}; + float prev_max_[ROWS_PER_THREAD] = {-HUGE_VALF}; + float prev_sum_[ROWS_PER_THREAD] = {0}; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_from_size_in_bytes { + using Type = Fragment(sizeof(Data_type_))>; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_from_size_in_bytes { + using Type = Fragment; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void clear(Fragment (&frag)[M][N]) { +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { + frag[mi][ni].clear(); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Clear_accumulator {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Clear_accumulator { + template + static inline __device__ void apply(Acc (&acc)[M][N], bool = false) { + fmha::clear(acc); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Clear_accumulator { + template + static inline __device__ void apply(Acc (&acc)[M][N], bool = false) { + fmha::clear(acc); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Clear_accumulator { + template + static inline __device__ void apply(Acc (&acc)[M][N], bool = false) { + fmha::clear(acc); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Clear_accumulator { + template + static inline __device__ void apply(Acc (&acc)[M][N], bool enable_i2f_trick = true) { +#if defined(USE_I2F_EMULATION_TRICK) + if (enable_i2f_trick) { +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { +#pragma unroll + for (int ii = 0; ii < Acc::NUM_REGS; ++ii) { + acc[mi][ni].reg(ii) = uint32_t(FP32_I2F_MAGIC_NUMBER_HEX) / WARPS_K; + } + } + } + } else +#endif // defined(USE_I2F_EMULATION_TRICK) + { + fmha::clear(acc); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/gemm.h b/csrc/fmha_v2/fmha/gemm.h new file mode 100644 index 0000000000..e1422e4f6e --- /dev/null +++ b/csrc/fmha_v2/fmha/gemm.h @@ -0,0 +1,35 @@ +/* + * 
SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm(Acc (&acc)[M][N], A const (&a)[M], B const (&b)[N]) { +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { + acc[mi][ni].mma(a[mi], b[ni]); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/gmem_tile_o.h b/csrc/fmha_v2/fmha/gmem_tile_o.h new file mode 100644 index 0000000000..c3177dc219 --- /dev/null +++ b/csrc/fmha_v2/fmha/gmem_tile_o.h @@ -0,0 +1,465 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include + +namespace fmha { +namespace v1 { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// H M M A +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hmma_gmem_tile_o { + // The mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The size of each element. + enum { BYTES_PER_ELEMENT = 2 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // The size of each STG. + enum { BYTES_PER_STG = 16 }; + + // The number of threads to store a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG }; + + // The number of "rows" stored per STG. + enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of "rows" stored per iteration of the loop. + enum { ROWS = Cta_tile::M }; + + // We want at least one output per thread (if possible). + enum { ROWS_PER_LOOP_ = ROWS <= 64 ? ROWS : (int)Min::VALUE }; + + // We also want to have "complete" MMAs. + enum { ROWS_PER_LOOP = Max::VALUE }; + + // The number of outer loop for the stores. + enum { LOOPS = fmha::Div_up::VALUE }; + + // DEBUG. + static_assert(ROWS % ROWS_PER_LOOP == 0, ""); + // END OF DEBUG. + + // Make sure the math is correct. 
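+  // Each loop iteration must cover at least one full MMA in the M dimension.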
+ static_assert(ROWS_PER_LOOP >= (int)Mma_tile::M_PER_MMA_PER_CTA, ""); + + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 }; + + // The number of STGs needed to store a chunk of the Q matrix. + enum { STGS_PER_LOOP = fmha::Div_up::VALUE }; + + // The number of STGs needed to store a chunk of the Q matrix in total. + enum { STGS = STGS_PER_LOOP * LOOPS }; + + // Ctor. + template + inline __device__ Hmma_gmem_tile_o(Params const& params, Block_info const& binfo, int tidx, + int cta_row_offset = 0) + : params_o_stride_in_bytes_(params.o_stride_in_bytes), + o_ptr_(reinterpret_cast(params.o_ptr)) { + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // Is that thread active on the last STG? + if (HAS_INCOMPLETE_STG) { + is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; + } + + // Account for the CTA-wide row offset (no loop mode). + row += cta_row_offset; + + // The row offset in the batched GEMM. + int64_t row_offset = (int64_t)row * params.o_stride_in_bytes; + // Take the batch/head offset into account. + row_offset += (int64_t)binfo.bidx * BYTES_PER_ROW; + // Assemble the final pointer. + o_ptr_ += row_offset + col * BYTES_PER_STG; + } + + // Load data from global memory. + inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) { + if (blockIdx.x == 0) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + dst[ii] = make_uint4(0u, 0u, 0u, 0u); + } + } else { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int jj = mi * STGS_PER_LOOP + ii; + if (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || is_active_for_last_stg_)) { + fmha::ldg(dst[ii], o_ptr_ + jj * ROWS_PER_STG * params_o_stride_in_bytes_); + } + } + } + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], int mi) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int jj = mi * STGS_PER_LOOP + ii; + if (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || is_active_for_last_stg_)) { + fmha::stg(o_ptr_ + jj * ROWS_PER_STG * params_o_stride_in_bytes_, src[ii]); + } + } + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], uint4 const (&old)[STGS_PER_LOOP], + int mi) { + uint4 tmp[STGS_PER_LOOP]; +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + tmp[ii].x = fmha::hadd2(src[ii].x, old[ii].x); + tmp[ii].y = fmha::hadd2(src[ii].y, old[ii].y); + tmp[ii].z = fmha::hadd2(src[ii].z, old[ii].z); + tmp[ii].w = fmha::hadd2(src[ii].w, old[ii].w); + } + this->store(tmp, mi); + } + + // Move the pointer to the next location. + inline __device__ void move() { o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_; } + + // The stride between rows for the QKV matrice. + int64_t const params_o_stride_in_bytes_; + // The pointer. + char* o_ptr_; + // Is the thread active for the last STG? + int is_active_for_last_stg_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Hmma_gmem_tile_o { + // The traits. + using Traits = fmha::Volta_hmma_fp16_16x16x16_traits; + // The base class. + using Base = Hmma_gmem_tile_o; + + // Ctor. 
+ template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& binfo, int tidx, + int cta_row_offset = 0) + : Base(params, binfo, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Hmma_gmem_tile_o { + // The traits. + using Traits = fmha::Turing_hmma_fp16_traits; + // The base class. + using Base = Hmma_gmem_tile_o; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& binfo, int tidx, + int cta_row_offset = 0) + : Base(params, binfo, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Hmma_gmem_tile_o { + // The traits. + using Traits = fmha::Ampere_hmma_fp16_traits; + // The base class. + using Base = Hmma_gmem_tile_o; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& binfo, int tidx, + int cta_row_offset = 0) + : Base(params, binfo, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// I M M A +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Imma_gmem_tile_o { + // The mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The size of each element. + enum { BYTES_PER_ELEMENT = 1 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // The size of each STG. + enum { BYTES_PER_STG = 4 }; + + // The number of threads to store a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG }; + + // The number of "rows" stored per STG. + enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + enum { ROWS = Cta_tile::M }; + + // We want at least one output per thread (if possible). + enum { ROWS_PER_LOOP_ = ROWS <= 64 ? ROWS : (int)Min::VALUE }; + + // We also want to have "complete" MMAs. + enum { ROWS_PER_LOOP = Max::VALUE }; + + // The number of outer loop for the stores. + enum { LOOPS = fmha::Div_up::VALUE }; + + // DEBUG. + static_assert(ROWS % ROWS_PER_LOOP == 0, ""); + // END OF DEBUG. + + // Make sure the math is correct. + static_assert(ROWS_PER_LOOP >= (int)Mma_tile::M_PER_MMA_PER_CTA, ""); + + // Do we have to guard against partial writes/reads (last STG). + enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 }; + + // The number of STGs needed to store a chunk of the Q matrix. + enum { STGS_PER_LOOP = fmha::Div_up::VALUE }; + + // The number of STGs needed to store a chunk of the Q matrix in total. + enum { STGS = STGS_PER_LOOP * LOOPS }; + + // Are all threads active? + enum { ALL_THREADS_ACTIVE = ROWS_PER_STG <= ROWS_PER_LOOP }; + + // The number of active threads. + enum { ACTIVE_THREADS_ = Cta_tile::THREADS_PER_CTA * ROWS_PER_LOOP / ROWS_PER_STG }; + + // The number of active threads. + enum { ACTIVE_THREADS = ALL_THREADS_ACTIVE ? Cta_tile::THREADS_PER_CTA : ACTIVE_THREADS_ }; + + // Ctor. 
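+  // Besides the usual row/col mapping, the ctor sets up the INT32 scratch pointer used when several CTAs cooperate on one head (CTAS_PER_HEAD > 1).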
+ template + inline __device__ Imma_gmem_tile_o(Params const& params, int bidx, int tidx, int cta_row_offset) + : params_o_stride_in_bytes_(params.o_stride_in_bytes), + params_scale_bmm2_(params.scale_bmm2), + params_enable_i2f_trick_(params.enable_i2f_trick), + o_ptr_(reinterpret_cast(params.o_ptr)) +#if USE_DEMO_BERT_PARAMS + , + o_scratch_ptr_(nullptr) { +#else + , + o_scratch_ptr_(reinterpret_cast(params.o_scratch_ptr)) { +#endif + + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // Is it an active thread? + is_active_ = ALL_THREADS_ACTIVE || row < ROWS_PER_LOOP; + + // Is that thread active on the last STG? + if (HAS_INCOMPLETE_STG) { + is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; + } + + // Update the row. + row += cta_row_offset; + + // The row offset in the batched GEMM. + int64_t row_offset = (int64_t)row * params.o_stride_in_bytes; + // Take the batch/head offset into account. + row_offset += (int64_t)bidx * BYTES_PER_ROW; + // Assemble the final pointers. + o_ptr_ += row_offset + col * BYTES_PER_STG; + + // For the scratch space, the pointer has int32 type so it accounts for the *4 factor. + o_scratch_ptr_ += blockIdx.y * STGS_PER_LOOP * ACTIVE_THREADS + tidx; + } + + // Load data from global memory. + inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) { + if (blockIdx.x == 0) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + dst[ii] = make_uint4(0u, 0u, 0u, 0u); + } + } else if (ALL_THREADS_ACTIVE || is_active_) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + fmha::ldg(dst[ii], o_scratch_ptr_ + ii * ACTIVE_THREADS); + } + } + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], int mi) { + // The scale. + float const& scale = reinterpret_cast(params_scale_bmm2_); +// Iterate over the different STGs. +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + // The accumulators are in int32_t. + int4 const& val = reinterpret_cast(src[ii]); + + // Extract the floats and scale. + float f0, f1, f2, f3; +#if defined(USE_I2F_EMULATION_TRICK) + if (params_enable_i2f_trick_) { + f0 = reinterpret_cast(val.x) - FP32_I2F_MAGIC_NUMBER; + f1 = reinterpret_cast(val.y) - FP32_I2F_MAGIC_NUMBER; + f2 = reinterpret_cast(val.z) - FP32_I2F_MAGIC_NUMBER; + f3 = reinterpret_cast(val.w) - FP32_I2F_MAGIC_NUMBER; + } else +#endif // defined(USE_I2F_EMULATION_TRICK) + { + f0 = static_cast(val.x); + f1 = static_cast(val.y); + f2 = static_cast(val.z); + f3 = static_cast(val.w); + } + + // Apply the scaling. + f0 *= scale; + f1 *= scale; + f2 *= scale; + f3 *= scale; + + // Convert the 4 floats to char4. + uint32_t dst = float4_to_char4(f0, f1, f2, f3); + + // Store the result. + int jj = mi * STGS_PER_LOOP + ii; + if (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || is_active_for_last_stg_)) { + fmha::stg(o_ptr_ + jj * ROWS_PER_STG * params_o_stride_in_bytes_, dst); + } + } + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], uint4 const (&old)[STGS_PER_LOOP], + int mi) { + // Do the reduction. 
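+    // Partial results from earlier CTAs are accumulated in INT32; only the last CTA quantizes to INT8 and writes the final output.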
+ uint4 tmp[STGS_PER_LOOP]; +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int4 const& src_ii = reinterpret_cast(src[ii]); + int4 const& old_ii = reinterpret_cast(old[ii]); + + int32_t x = src_ii.x + old_ii.x; + int32_t y = src_ii.y + old_ii.y; + int32_t z = src_ii.z + old_ii.z; + int32_t w = src_ii.w + old_ii.w; + + tmp[ii].x = reinterpret_cast(x); + tmp[ii].y = reinterpret_cast(y); + tmp[ii].z = reinterpret_cast(z); + tmp[ii].w = reinterpret_cast(w); + } + + // The last CTA stores INT8 values to the final location. + if (blockIdx.x == CTAS_PER_HEAD - 1) { + this->store(tmp, mi); + + // Other CTAs store INT32 values to the scratch space. + } else if (ALL_THREADS_ACTIVE || is_active_) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + fmha::stg(o_scratch_ptr_ + ii * ACTIVE_THREADS, tmp[ii]); + } + } + } + + // Move the pointer. + inline __device__ void move() { o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_; } + + // The stride between rows for the QKV matrice. + int64_t const params_o_stride_in_bytes_; + // The scaling factor to convert to int8. + uint32_t const params_scale_bmm2_; + // Do we enable the i2f trick? + bool const params_enable_i2f_trick_; + // The pointer. + char* o_ptr_; + // The scratch pointer for 32-bit reductions. + int32_t* o_scratch_ptr_; + + // Is it an active thread? When ROWS_PER_STG > ROWS_PER_LOOP, some threads do not store. + int is_active_, is_active_for_last_stg_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Imma_gmem_tile_o { + // The traits class. + using Traits = fmha::Turing_imma_int8_int32_traits; + // The base class. + using Base = Imma_gmem_tile_o; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0) + : Base(params, block_info.bidx, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Imma_gmem_tile_o { + // The traits class. + using Traits = fmha::Ampere_imma_int8_int32_traits; + // The base class. + using Base = Imma_gmem_tile_o; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0) + : Base(params, block_info.bidx, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace v1 +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/gmem_tile_o_packed.h b/csrc/fmha_v2/fmha/gmem_tile_o_packed.h new file mode 100644 index 0000000000..dc13b37f19 --- /dev/null +++ b/csrc/fmha_v2/fmha/gmem_tile_o_packed.h @@ -0,0 +1,1349 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#pragma once + +#include +#include +#include + +namespace fmha { +namespace v2 { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// H M M A +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hmma_gmem_tile_o { + // The mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The size of each element. + enum { BYTES_PER_ELEMENT = BYTES_PER_ELEMENT_ }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // The valid size of a row in bytes. + // Note: cross-attention kernels rely on head dim from runtime instead of from compile-time. + // This approach deviates from self-attention kernels. To explore a unified approach. + // enum { VALID_BYTES_PER_ROW = Cta_tile::VALID_N * BYTES_PER_ELEMENT }; + + // The size of each STG. + enum { BYTES_PER_STG = BYTES_PER_STG_ }; + + // The number of threads to store a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG }; + + // The number of "rows" stored per STG. + enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + enum { ROWS = Cta_tile::M }; + + // We want at least one output per thread (if possible). + enum { ROWS_PER_LOOP_ = ROWS <= 64 ? ROWS : (int)Min::VALUE }; + + // We also want to have "complete" MMAs. + enum { ROWS_PER_LOOP = Max::VALUE }; + + // The number of outer loop for the stores. + enum { LOOPS = fmha::Div_up::VALUE }; + + // DEBUG. + static_assert(ROWS % ROWS_PER_LOOP == 0, ""); + // END OF DEBUG. + + // Make sure the math is correct. + static_assert(ROWS_PER_LOOP >= (int)Mma_tile::M_PER_MMA_PER_CTA, ""); + + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 }; + + // The number of STGs needed to store a chunk of the Q matrix. + enum { STGS_PER_LOOP = fmha::Div_up::VALUE }; + + // The number of STGs needed to store a chunk of the Q matrix in total. + enum { STGS = STGS_PER_LOOP * LOOPS }; + + // Ctor. + template + inline __device__ Hmma_gmem_tile_o(Params const& params, Block_info const& binfo, int tidx, + int cta_row_offset, int cta_col_offset_in_bytes = 0) + : params_o_stride_in_bytes_(params.o_stride_in_bytes), + actual_seqlen_(binfo.actual_q_seqlen), + o_ptr_(reinterpret_cast(params.o_ptr)) { + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // Is that thread active on the last STG? + if (HAS_INCOMPLETE_STG) { + is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; + } + + // Store the row/col to update the predicates in load. + row_ = cta_row_offset + row; + col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_STG; + init_row_ = row_; + + // The row offset in the batched GEMM. + int64_t row_offset = (int64_t)row_ * params.o_stride_in_bytes; + // The amount of bytes per row without padding. + int const valid_bytes_per_row = params.dv * BYTES_PER_ELEMENT; + // Take the batch/head offset into account. TODO: Fix me! 
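+    // The batch/head offset uses the runtime head size (params.dv) rather than the compile-time VALID_BYTES_PER_ROW, so padding in the head dimension is skipped.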
+ // + // row_offset += binfo.bidx * VALID_BYTES_PER_ROW; + // + row_offset += binfo.bidx * valid_bytes_per_row; + + // Assemble the final pointer. + o_ptr_ += row_offset + col_in_bytes_; + init_o_ptr_ = o_ptr_; + + // Do not store if the thread is in the padded area + active_ = col_in_bytes_ < valid_bytes_per_row; + } + + // Load data from global memory. + inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) { + if (blockIdx.x == 0) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + dst[ii] = make_uint4(0u, 0u, 0u, 0u); + } + } else { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int jj = mi * STGS_PER_LOOP + ii; + if (row_ + jj * ROWS_PER_STG >= actual_seqlen_) { + break; + } + if (active_ && (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || is_active_for_last_stg_))) { + fmha::ldg(dst[ii], o_ptr_ + jj * ROWS_PER_STG * params_o_stride_in_bytes_); + } + } + } + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], int mi) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int jj = mi * STGS_PER_LOOP + ii; + if (row_ + jj * ROWS_PER_STG >= actual_seqlen_) { + break; + } + if (active_ && (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || is_active_for_last_stg_))) { + fmha::stg(o_ptr_ + jj * ROWS_PER_STG * params_o_stride_in_bytes_, src[ii]); + } + } + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], uint4 const (&old)[STGS_PER_LOOP], + int mi) { + uint4 tmp[STGS_PER_LOOP]; +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + tmp[ii].x = fmha::hadd2(src[ii].x, old[ii].x); + tmp[ii].y = fmha::hadd2(src[ii].y, old[ii].y); + tmp[ii].z = fmha::hadd2(src[ii].z, old[ii].z); + tmp[ii].w = fmha::hadd2(src[ii].w, old[ii].w); + } + this->store(tmp, mi); + } + + // Move the pointer to the next location. + inline __device__ void move(int const steps = 1) { + row_ += ROWS * steps; + o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_ * steps; + } + + inline __device__ void move_to(int const step) { + row_ = init_row_ + ROWS * step; + o_ptr_ = init_o_ptr_ + (int64_t)ROWS * params_o_stride_in_bytes_ * step; + } + + // The stride between rows for the QKV matrice. + int64_t params_o_stride_in_bytes_; + // The pointer. + char* o_ptr_; + char* init_o_ptr_; + // Is the thread active for the last STG? + int is_active_for_last_stg_; + + // The row loaded by this thread. + int row_, col_in_bytes_; + int init_row_; + // The length of the sequence loaded by that CTA. + int actual_seqlen_; + // Is that thread active when it comes to loading data? + int active_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Hmma_gmem_tile_o { + // The traits. + using Traits = fmha::Volta_hmma_fp16_16x16x16_traits; + // The base class. + using Base = Hmma_gmem_tile_o; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0) + : Base(params, block_info, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Hmma_gmem_tile_o { + // The traits. + using Traits = fmha::Turing_hmma_fp16_traits; + // The base class. + using Base = Hmma_gmem_tile_o; + + // Ctor. 
+ template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Hmma_gmem_tile_o { + // The traits. + using Traits = fmha::Ampere_hmma_fp16_traits; + // The base class. + using Base = Hmma_gmem_tile_o; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Hmma_gmem_tile_o { + // The traits. + using Traits = fmha::Ampere_hmma_bf16_bf16_traits; + // The base class. + using Base = Hmma_gmem_tile_o; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Hmma_gmem_tile_o { + // The traits. + using Traits = fmha::Ampere_hmma_fp32_traits; + // The base class. + using Base = Hmma_gmem_tile_o; + + // The epilogue data type + using Epilogue_type = typename Traits::Epilogue_type; + + // DEBUG. + static_assert((Base::THREADS_PER_ROW == 16 || Base::THREADS_PER_ROW == 32 || + Base::THREADS_PER_ROW == 64 || Base::THREADS_PER_ROW == 128) && + Base::BYTES_PER_STG == 8, + ""); + + // END OF DEBUG. + + enum { STGS_PER_LOOP = Base::STGS_PER_LOOP }; + + enum { ROWS_PER_STG = Base::ROWS_PER_STG }; + + enum { STGS = Base::STGS }; + + enum { HAS_INCOMPLETE_STG = Base::HAS_INCOMPLETE_STG }; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} + + // Load data from global memory. + inline __device__ void load(uint4 const (&dst)[STGS_PER_LOOP], int mi) { + static_assert(CTAS_PER_HEAD == 1, "Not implemented"); + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], int mi) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int jj = mi * STGS_PER_LOOP + ii; + if (this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_) { + break; + } + + float x = reinterpret_cast(src[ii].x); + float y = reinterpret_cast(src[ii].y); + float z = reinterpret_cast(src[ii].z); + float w = reinterpret_cast(src[ii].w); + + uint2 out = float4_to_16bit_x4(x, y, z, w); + if (this->active_ && + (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_))) { + fmha::stg(this->o_ptr_ + jj * ROWS_PER_STG * this->params_o_stride_in_bytes_, out); + } + } + } + + // Store data to global memory. 
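+  // The multi-CTA reduction path is not implemented for this epilogue; CTAS_PER_HEAD must be 1.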
+ inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], uint4 const (&old)[STGS_PER_LOOP], + int mi) { + static_assert(CTAS_PER_HEAD == 1, "Not implemented"); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Hmma_gmem_tile_o { + // The traits. + using Traits = fmha::Ampere_hmma_bf16_traits; + // The base class. + using Base = Hmma_gmem_tile_o; + + // The epilogue data type + using Epilogue_type = typename Traits::Epilogue_type; + + // DEBUG. + static_assert((Base::THREADS_PER_ROW == 16 || Base::THREADS_PER_ROW == 32 || + Base::THREADS_PER_ROW == 64 || Base::THREADS_PER_ROW == 128) && + Base::BYTES_PER_STG == 8, + ""); + + // END OF DEBUG. + + enum { STGS_PER_LOOP = Base::STGS_PER_LOOP }; + + enum { ROWS_PER_STG = Base::ROWS_PER_STG }; + + enum { STGS = Base::STGS }; + + enum { HAS_INCOMPLETE_STG = Base::HAS_INCOMPLETE_STG }; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} + + // Load data from global memory. + inline __device__ void load(uint4 const (&dst)[STGS_PER_LOOP], int mi) { + static_assert(CTAS_PER_HEAD == 1, "Not implemented"); + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], int mi) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int jj = mi * STGS_PER_LOOP + ii; + if (this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_) { + break; + } + + float x = reinterpret_cast(src[ii].x); + float y = reinterpret_cast(src[ii].y); + float z = reinterpret_cast(src[ii].z); + float w = reinterpret_cast(src[ii].w); + + uint2 out = float4_to_16bit_x4(x, y, z, w); + if (this->active_ && + (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_))) { + fmha::stg(this->o_ptr_ + jj * ROWS_PER_STG * this->params_o_stride_in_bytes_, out); + } + } + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], uint4 const (&old)[STGS_PER_LOOP], + int mi) { + static_assert(CTAS_PER_HEAD == 1, "Not implemented"); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// I M M A +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t quantize(int4 const val, float const scale, + bool const params_enable_i2f_trick) { + // Extract the floats and scale. + float f0, f1, f2, f3; +#if defined(USE_I2F_EMULATION_TRICK) + if (params_enable_i2f_trick) { + f0 = reinterpret_cast(val.x) - FP32_I2F_MAGIC_NUMBER; + f1 = reinterpret_cast(val.y) - FP32_I2F_MAGIC_NUMBER; + f2 = reinterpret_cast(val.z) - FP32_I2F_MAGIC_NUMBER; + f3 = reinterpret_cast(val.w) - FP32_I2F_MAGIC_NUMBER; + } else +#endif // defined(USE_I2F_EMULATION_TRICK) + { + f0 = static_cast(val.x); + f1 = static_cast(val.y); + f2 = static_cast(val.z); + f3 = static_cast(val.w); + } + + // Apply the scaling. + f0 *= scale; + f1 *= scale; + f2 *= scale; + f3 *= scale; + + // Convert the 4 floats to char4. 
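+  // float4_to_char4 packs the four scaled floats into a single 32-bit register holding four int8 values.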
+ uint32_t dst = float4_to_char4(f0, f1, f2, f3); + + return dst; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Helpers to pack 4 registers representing a Src_type into a destination register with 4 8bit +// values representing Dst_type. Scale factor is assumed to be always FP32 for 32-bit accumulators. +template +struct Acc_packer {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Signed INT32 => INT8. +template <> +struct Acc_packer { + template + static inline __device__ uint32_t run(This const* this_, uint4 const& src_regs) { + float const& scale = reinterpret_cast(this_->params_scale_bmm2_); + // The accumulators are in int32_t. + int4 const& val = reinterpret_cast(src_regs); + + // Quantize... + uint32_t dst = quantize(val, scale, this_->params_enable_i2f_trick_); + return dst; + } +}; + +template <> +struct Acc_packer { + template + static inline __device__ uint32_t run(This const* this_, uint4 const& src_regs) { + // The accumulators are in int32_t. + int4 const& val = reinterpret_cast(src_regs); + + // Quantize... + uint32_t dst = quantize(val, 1.0f, this_->params_enable_i2f_trick_); + return dst; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// FP32 => FP8. +template <> +struct Acc_packer { + template + static inline __device__ uint32_t run(This const* this_, uint4 const& src_regs) { + float const scale = reinterpret_cast(this_->params_scale_bmm2_); + + float4 const& val = reinterpret_cast(src_regs); + + uint32_t dst = + fmha::float4_to_e4m3x4(val.x * scale, val.y * scale, val.z * scale, val.w * scale); + return dst; + } + + template + static inline __device__ uint16_t run(This const* this_, uint2 const& src_regs) { + float const& scale = reinterpret_cast(this_->params_scale_bmm2_); + + float2 const& val = reinterpret_cast(src_regs); + + uint16_t dst = fmha::float2_to_e4m3x2(val.x * scale, val.y * scale); + return dst; + } +}; + +// FP32 => FP8. +template <> +struct Acc_packer { + template + static inline __device__ uint32_t run(This const* this_, uint4 const& src_regs) { + float4 const& val = reinterpret_cast(src_regs); + + uint32_t dst = fmha::float4_to_e4m3x4(val.x, val.y, val.z, val.w); + return dst; + } + + template + static inline __device__ uint16_t run(This const* this_, uint2 const& src_regs) { + float2 const& val = reinterpret_cast(src_regs); + + uint16_t dst = fmha::float2_to_e4m3x2(val.x, val.y); + return dst; + } +}; + +// FP16 => FP8. +template <> +struct Acc_packer { + template + static inline __device__ uint2 run(This const* this_, uint4 const& src_regs) { + uint2 dst; + dst.x = fmha::half4_to_e4m3x4(fmha::hmul2(src_regs.x, this_->params_scale_bmm2_), + fmha::hmul2(src_regs.y, this_->params_scale_bmm2_)); + dst.y = fmha::half4_to_e4m3x4(fmha::hmul2(src_regs.z, this_->params_scale_bmm2_), + fmha::hmul2(src_regs.w, this_->params_scale_bmm2_)); + + return dst; + } +}; + +// FP16 => FP8. 
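+// Unscaled variant: the FP16 accumulators are packed straight to e4m3 without multiplying by scale_bmm2.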
+template <> +struct Acc_packer { + template + static inline __device__ uint2 run(This const* this_, uint4 const& src_regs) { + uint2 dst; + dst.x = fmha::half4_to_e4m3x4(src_regs.x, src_regs.y); + dst.y = fmha::half4_to_e4m3x4(src_regs.z, src_regs.w); + + return dst; + } +}; + +template <> +struct Acc_packer { + template + static inline __device__ uint32_t run(This const* this_, uint4 const& src_regs) { + float const& scale = reinterpret_cast(this_->params_scale_bmm2_); + + float4 const& val = reinterpret_cast(src_regs); + + uint32_t dst = + fmha::float4_to_e5m2x4(val.x * scale, val.y * scale, val.z * scale, val.w * scale); + return dst; + } +}; + +template <> +struct Acc_packer { + template + static inline __device__ uint32_t run(This const* this_, uint4 const& src_regs) { + float4 const& val = reinterpret_cast(src_regs); + + uint32_t dst = fmha::float4_to_e5m2x4(val.x, val.y, val.z, val.w); + return dst; + } +}; + +template <> +struct Acc_packer { + template + static inline __device__ uint2 run(This const* this_, uint4 const& src_regs) { + float4 const& val = reinterpret_cast(src_regs); + + uint2 dst = fmha::float4_to_half4(val.x, val.y, val.z, val.w); + return dst; + } +}; + +template <> +struct Acc_packer { + template + static inline __device__ uint2 run(This const* this_, uint4 const& src_regs) { + float const& scale = reinterpret_cast(this_->params_scale_bmm2_); + + float4 const& val = reinterpret_cast(src_regs); + + uint2 dst = fmha::float4_to_half4(val.x * scale, val.y * scale, val.z * scale, val.w * scale); + return dst; + } +}; + +template <> +struct Acc_packer { + template + static inline __device__ uint2 run(This const* this_, uint4 const& src_regs) { + float4 const& val = reinterpret_cast(src_regs); + + uint2 dst = fmha::float4_to_16bit_x4(val.x, val.y, val.z, val.w); + return dst; + } +}; + +template <> +struct Acc_packer { + template + static inline __device__ uint2 run(This const* this_, uint4 const& src_regs) { + float const& scale = reinterpret_cast(this_->params_scale_bmm2_); + + float4 const& val = reinterpret_cast(src_regs); + + uint2 dst = fmha::float4_to_16bit_x4(val.x * scale, val.y * scale, val.z * scale, + val.w * scale); + return dst; + } +}; + +// support both 32 bit accumulationi and 16 bit accumulation (imma and qmma) +template +struct Gmem_tile_o_8bit { + // static_assert(sizeof(typename Traits::Accumulator_type) == 4); + static_assert(sizeof(typename Traits::C_type) == 1); + + // The mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The size of each element. + enum { BYTES_PER_ELEMENT = 1 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // The valid size of a row in bytes. + enum { VALID_BYTES_PER_ROW = Cta_tile::VALID_N * BYTES_PER_ELEMENT }; + + // The size of each STG (16B --> 8bit elements). + enum { BYTES_PER_STG = fmha::Div_up<16, sizeof(typename Traits::Accumulator_type)>::VALUE }; + + // The STG packed data type + using Stg_packed_type = typename Uint_from_size_in_bytes::Type; + + // The number of threads to store a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG }; + + // The number of "rows" stored per STG. + enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + enum { ROWS = Cta_tile::M }; + + // We want at least one output per thread (if possible). + enum { ROWS_PER_LOOP_ = ROWS <= 64 ? 
ROWS : (int)Min::VALUE }; + + // We also want to have "complete" MMAs. + enum { ROWS_PER_LOOP = Max::VALUE }; + + // The number of outer loop for the stores. + enum { LOOPS = fmha::Div_up::VALUE }; + + // DEBUG. + static_assert(ROWS % ROWS_PER_LOOP == 0, ""); + + // Make sure the math is correct. + static_assert(ROWS_PER_LOOP >= (int)Mma_tile::M_PER_MMA_PER_CTA, ""); + + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 }; + + // The number of STGs needed to store a chunk of the Q matrix. + enum { STGS_PER_LOOP = fmha::Div_up::VALUE }; + + // The number of STGs needed to store a chunk of the Q matrix in total. + enum { STGS = STGS_PER_LOOP * LOOPS }; + +#if 0 + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + enum { ROWS = Cta_tile::M }; + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + enum { ROWS_PER_LOOP = Mma_tile::M_PER_MMA_PER_CTA }; + // The number of outer loop for the stores. + enum { LOOPS = ROWS / ROWS_PER_LOOP }; + + // Make sure the math is correct. + static_assert(LOOPS == (int)Mma_tile::MMAS_M, ""); + + // The number of "rows" stored per STG -- for it to be the number of rows per MMA instruction. + enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + // The number of STGs needed to store a chunk of the Q matrix. + enum { STGS_PER_LOOP = fmha::Div_up::VALUE }; +#endif + + // Are all threads active? + enum { ALL_THREADS_ACTIVE = ROWS_PER_STG <= ROWS_PER_LOOP }; + + // The number of active threads. + enum { ACTIVE_THREADS_ = Cta_tile::THREADS_PER_CTA * ROWS_PER_LOOP / ROWS_PER_STG }; + + // The number of active threads. + enum { ACTIVE_THREADS = ALL_THREADS_ACTIVE ? Cta_tile::THREADS_PER_CTA : ACTIVE_THREADS_ }; + + // Ctor. + template + inline __device__ Gmem_tile_o_8bit(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : params_o_stride_in_bytes_(params.o_stride_in_bytes), + actual_seqlen_(block_info.actual_q_seqlen), + params_scale_bmm2_(params.scale_bmm2_d ? *params.scale_bmm2_d : params.scale_bmm2) +#ifdef GENERATE_CUBIN + , + params_enable_i2f_trick_(false) +#else + , + params_enable_i2f_trick_(params.enable_i2f_trick) +#endif + , + o_ptr_(reinterpret_cast(params.o_ptr)) +#if USE_DEMO_BERT_PARAMS + , + o_scratch_ptr_(nullptr) { +#else + , + o_scratch_ptr_(reinterpret_cast(params.o_scratch_ptr)) { +#endif + + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // Is it an active thread for the very last STG? + if (HAS_INCOMPLETE_STG) { + is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; + } + + // Store the row to check against the length before loads. + row_ = cta_row_offset + row; + col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_STG; + + // The row offset in the batched GEMM. + int64_t row_offset = (int64_t)row_ * params.o_stride_in_bytes; + // The amount of bytes per row without padding (runtime). + int const valid_bytes_per_row = params.dv * BYTES_PER_ELEMENT; + // Take the batch/head offset into account. + row_offset += block_info.bidx * valid_bytes_per_row; + // Assemble the final pointer. + o_ptr_ += row_offset + col_in_bytes_; + + // Is it an active thread? 
+ is_active_ = ALL_THREADS_ACTIVE || (row < ROWS_PER_LOOP && col_in_bytes_ < VALID_BYTES_PER_ROW); + + // Do not store if the thread is in the padded area + is_active_ = is_active_ && col < valid_bytes_per_row / BYTES_PER_STG; + + // For the scratch space, the pointer has int32 type so it accounts for the *4 factor. + o_scratch_ptr_ += blockIdx.y * STGS_PER_LOOP * ACTIVE_THREADS + tidx; + } + + // Load data from global memory. + inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) { + if (blockIdx.x == 0) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + dst[ii] = make_uint4(0u, 0u, 0u, 0u); + } + } else if (ALL_THREADS_ACTIVE || is_active_) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + fmha::ldg(dst[ii], o_scratch_ptr_ + ii * ACTIVE_THREADS); + } + } + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], int mi) { +// Iterate over the different STGs. +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + // Break early if we exceed s_i... + int jj = mi * STGS_PER_LOOP + ii; + if (row_ + jj * ROWS_PER_STG >= actual_seqlen_) { + return; + } + using Src_type = typename Traits::Accumulator_type; + using Dst_type = typename Traits::C_type; + // Packs the 32bit/16bit values to 8bit. + // Depending on the type, applies extra scaling with parameter scale_bmm2. + Stg_packed_type dst = Acc_packer::run(this, src[ii]); + float const* row_ptr = reinterpret_cast(&src[ii]); + + // Store the result. + if (is_active_ && (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || is_active_for_last_stg_))) { + fmha::stg(o_ptr_ + jj * ROWS_PER_STG * params_o_stride_in_bytes_, dst); + } + } + } + + // Store data to global memory. + // TODO: 16bit (half) + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], uint4 const (&old)[STGS_PER_LOOP], + int mi) { + // Do the reduction. + uint4 tmp[STGS_PER_LOOP]; +#if defined(USE_I2F_EMULATION_TRICK) + if (params_enable_i2f_trick_) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + float4 const& src_ii = reinterpret_cast(src[ii]); + float4 const& old_ii = reinterpret_cast(old[ii]); + + float x = src_ii.x + old_ii.x; + float y = src_ii.y + old_ii.y; + float z = src_ii.z + old_ii.z; + float w = src_ii.w + old_ii.w; + + tmp[ii].x = reinterpret_cast(x); + tmp[ii].y = reinterpret_cast(y); + tmp[ii].z = reinterpret_cast(z); + tmp[ii].w = reinterpret_cast(w); + } + } else +#endif + { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int4 const& src_ii = reinterpret_cast(src[ii]); + int4 const& old_ii = reinterpret_cast(old[ii]); + + int32_t x = src_ii.x + old_ii.x; + int32_t y = src_ii.y + old_ii.y; + int32_t z = src_ii.z + old_ii.z; + int32_t w = src_ii.w + old_ii.w; + + tmp[ii].x = reinterpret_cast(x); + tmp[ii].y = reinterpret_cast(y); + tmp[ii].z = reinterpret_cast(z); + tmp[ii].w = reinterpret_cast(w); + } + } + + // The last CTA stores INT8 values to the final location. + if (blockIdx.x == CTAS_PER_HEAD - 1) { + this->store(tmp, mi); + + // Other CTAs store INT32 values to the scratch space. + } else if (ALL_THREADS_ACTIVE || is_active_) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + fmha::stg(o_scratch_ptr_ + ii * ACTIVE_THREADS, tmp[ii]); + } + } + } + + // Move the pointer. + inline __device__ void move() { + row_ += ROWS; + o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_; + } + + // The stride between rows for the QKV matrice. + int64_t params_o_stride_in_bytes_; + // The scaling factor to convert to int8. 
+ uint32_t const params_scale_bmm2_; + // Do we enable the i2f trick? + bool const params_enable_i2f_trick_; + // The pointer. + char* o_ptr_; + // The pointer to the scratch space to do the reduction (for CTAS_PER_HEAD > 1). + uint4* o_scratch_ptr_; + // The row, col stored by this thread (i.e. the position in that sequence). + int row_, col_in_bytes_; + // The size of the sequence length computed by that CTA. + int actual_seqlen_; + + // Is it an active thread? + int is_active_, is_active_for_last_stg_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Gmem_tile_o_8bit { + // The traits class. + using Traits = fmha::Volta_imma_int8_int32_traits; + // The base class. + using Base = Gmem_tile_o_8bit; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Gmem_tile_o_8bit { + // The traits class. + using Traits = fmha::Turing_imma_int8_int32_traits; + // The base class. + using Base = Gmem_tile_o_8bit; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Gmem_tile_o_8bit { + // The traits class. + using Traits = fmha::Ampere_imma_int8_int32_traits; + // The base class. + using Base = Gmem_tile_o_8bit; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Gmem_tile_o_8bit { + // The traits class. + using Traits = fmha::Ada_qmma_e4m3_fp32_traits; + // The base class. + using Base = Gmem_tile_o_8bit; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Gmem_tile_o_8bit { + // The traits class. + using Traits = fmha::Ada_qmma_e4m3_fp16_traits; + // The base class. + using Base = Gmem_tile_o_8bit; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_16bit { + // This stores the fp32 accumulators of Ada_qmma_e4m3_fp32_traits as 16bit values to + // the global memory. 
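+  // Compared to Gmem_tile_o_8bit above, each element is 2 bytes and every 16-byte accumulator
+  // chunk (uint4, i.e. four fp32 values) is packed by Acc_packer into four 16-bit outputs per STG.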
+
+  static_assert(std::is_same::value);
+  static_assert(std::is_same::value ||
+                std::is_same::value);
+
+  using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
+
+  // The size of each element.
+  enum { BYTES_PER_ELEMENT = 2 };
+
+  // The size of a row in bytes.
+  enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT };
+
+  // The valid size of a row in bytes.
+  // Note: cross-attention kernels rely on the head dim known at runtime instead of at compile
+  // time. This deviates from the self-attention kernels; TODO: explore a unified approach.
+  enum { VALID_BYTES_PER_ROW = Cta_tile::VALID_N * BYTES_PER_ELEMENT };
+
+  // The size of each STG.
+  enum { BYTES_PER_STG = 8 };
+
+  // The STG packed data type.
+  using Stg_packed_type = typename Uint_from_size_in_bytes<BYTES_PER_STG>::Type;
+
+  // The number of threads to store a "row" of the matrix.
+  enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG };
+
+  // The number of "rows" stored per STG.
+  enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
+
+  // The number of "rows" stored per iteration of the loop. The output of 1 MMA.
+  enum { ROWS = Cta_tile::M };
+
+  // We want at least one output per thread (if possible).
+  enum { ROWS_PER_LOOP_ = ROWS <= 64 ? ROWS : (int)Min::VALUE };
+
+  // We also want to have "complete" MMAs.
+  enum { ROWS_PER_LOOP = Max<ROWS_PER_LOOP_, Mma_tile::M_PER_MMA_PER_CTA>::VALUE };
+
+  // The number of outer loops for the stores.
+  enum { LOOPS = fmha::Div_up<ROWS, ROWS_PER_LOOP>::VALUE };
+
+  // DEBUG.
+  static_assert(ROWS % ROWS_PER_LOOP == 0, "");
+  // END OF DEBUG.
+
+  // Make sure the math is correct.
+  static_assert(ROWS_PER_LOOP >= (int)Mma_tile::M_PER_MMA_PER_CTA, "");
+
+  // Do we have to guard against partial writes/reads?
+  enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 };
+
+  // The number of STGs needed to store a chunk of the O matrix.
+  enum { STGS_PER_LOOP = fmha::Div_up<ROWS_PER_LOOP, ROWS_PER_STG>::VALUE };
+
+  // The number of STGs needed to store a chunk of the O matrix in total.
+  enum { STGS = STGS_PER_LOOP * LOOPS };
+
+  // Are all threads active?
+  enum { ALL_THREADS_ACTIVE = ROWS_PER_STG <= ROWS_PER_LOOP };
+
+  // The number of active threads.
+  enum { ACTIVE_THREADS_ = Cta_tile::THREADS_PER_CTA * ROWS_PER_LOOP / ROWS_PER_STG };
+
+  // The number of active threads.
+  enum { ACTIVE_THREADS = ALL_THREADS_ACTIVE ? Cta_tile::THREADS_PER_CTA : ACTIVE_THREADS_ };
+
+  // Ctor.
+  template <typename Params, typename Block_info>
+  inline __device__ Gmem_tile_o_16bit(Params const& params, Block_info const& block_info, int tidx,
+                                      int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
+      : params_o_stride_in_bytes_(params.o_stride_in_bytes),
+        actual_seqlen_(block_info.actual_q_seqlen),
+        params_scale_bmm2_(params.scale_bmm2_d ? *params.scale_bmm2_d : params.scale_bmm2)
+#ifdef GENERATE_CUBIN
+        ,
+        params_enable_i2f_trick_(false)
+#else
+        ,
+        params_enable_i2f_trick_(params.enable_i2f_trick)
+#endif
+        ,
+        o_ptr_(reinterpret_cast<char*>(params.o_ptr))
+#if USE_DEMO_BERT_PARAMS
+        ,
+        o_scratch_ptr_(nullptr) {
+#else
+        ,
+        o_scratch_ptr_(reinterpret_cast<uint4*>(params.o_scratch_ptr)) {
+#endif
+
+    // Compute the position in the sequence (within the CTA for the moment).
+    int row = tidx / THREADS_PER_ROW;
+    // Compute the position of the thread in the row.
+    int col = tidx % THREADS_PER_ROW;
+
+    // Is it an active thread for the very last STG?
+    if (HAS_INCOMPLETE_STG) {
+      is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M;
+    }
+
+    // Store the row to check against the length before loads.
+    row_ = cta_row_offset + row;
+    col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_STG;
+
+    // The row offset in the batched GEMM.
+    int64_t row_offset = (int64_t)row_ * params.o_stride_in_bytes;
+    // The amount of bytes per row without padding (runtime).
+    int const valid_bytes_per_row = params.dv * BYTES_PER_ELEMENT;
+    // Take the batch/head offset into account.
+    row_offset += block_info.bidx * valid_bytes_per_row;
+    // Assemble the final pointer.
+    o_ptr_ += row_offset + col_in_bytes_;
+
+    // Is it an active thread?
+    is_active_ = ALL_THREADS_ACTIVE || (row < ROWS_PER_LOOP && col_in_bytes_ < VALID_BYTES_PER_ROW);
+
+    // Do not store if the thread is in the padded area.
+    is_active_ = is_active_ && col < valid_bytes_per_row / BYTES_PER_STG;
+
+    // For the scratch space, the pointer has int32 type so it accounts for the *4 factor.
+    o_scratch_ptr_ += blockIdx.y * STGS_PER_LOOP * ACTIVE_THREADS + tidx;
+  }
+
+  // Store data to global memory.
+  inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], int mi) {
+// Iterate over the different STGs.
+#pragma unroll
+    for (int ii = 0; ii < STGS_PER_LOOP; ++ii) {
+      // Break early if we exceed s_i...
+      int jj = mi * STGS_PER_LOOP + ii;
+      if (row_ + jj * ROWS_PER_STG >= actual_seqlen_) {
+        return;
+      }
+      using Src_type = typename Traits::Accumulator_type;
+      // Packs the 32bit/16bit values to 16bit.
+      // Depending on the type, applies extra scaling with parameter scale_bmm2.
+      Stg_packed_type dst = Acc_packer::run(this, src[ii]);
+      float const* row_ptr = reinterpret_cast<float const*>(&src[ii]);
+
+      // Store the result.
+      if (is_active_ && (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || is_active_for_last_stg_))) {
+        fmha::stg(o_ptr_ + jj * ROWS_PER_STG * params_o_stride_in_bytes_, dst);
+      }
+    }
+  }
+
+  // The stride between rows for the O matrix.
+  int64_t params_o_stride_in_bytes_;
+  // The scaling factor used when converting the accumulators.
+  uint32_t const params_scale_bmm2_;
+  // Do we enable the i2f trick?
+  bool const params_enable_i2f_trick_;
+  // The pointer.
+  char* o_ptr_;
+  // The pointer to the scratch space to do the reduction (for CTAS_PER_HEAD > 1).
+  uint4* o_scratch_ptr_;
+  // The row, col stored by this thread (i.e. the position in that sequence).
+  int row_, col_in_bytes_;
+  // The size of the sequence length computed by that CTA.
+  int actual_seqlen_;
+
+  // Is it an active thread?
+  int is_active_, is_active_for_last_stg_;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct Gmem_tile_o_uint16 : public Gmem_tile_o_16bit {
+  using Base = Gmem_tile_o_16bit;
+
+  // Ctor.
+  template <typename Params, typename Block_info>
+  inline __device__ Gmem_tile_o_uint16(Params const& params, Block_info const& block_info, int tidx,
+                                       int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
+      : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {}
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct Gmem_tile_o_bfloat16
+    : public Gmem_tile_o_16bit {
+  using Base = Gmem_tile_o_16bit;
+
+  // Ctor.
+  template <typename Params, typename Block_info>
+  inline __device__ Gmem_tile_o_bfloat16(Params const& params, Block_info const& block_info,
+                                         int tidx, int cta_row_offset = 0,
+                                         int cta_col_offset_in_bytes = 0)
+      : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {}
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct Imma_gmem_tile_o_interleaved {
+  // The mma tile.
+ using Mma_tile = typename Traits::template Mma_tile; + + enum { VEC = 32 }; + + enum { NUM_SLICES = Cta_tile::N / VEC }; + + // DEBUG. + static_assert(NUM_SLICES == 1 || NUM_SLICES == 2, ""); + + // END OF DEBUG. + + // The size of each element. + enum { BYTES_PER_ELEMENT = 1 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = VEC * BYTES_PER_ELEMENT }; + + // The size of each STG. + enum { BYTES_PER_STG = 4 }; + + // The number of threads to store a "row" of the matrix. We force it to 8 + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG }; + + // DEBUG. + static_assert(THREADS_PER_ROW == 8 && BYTES_PER_STG == 4, ""); + + // END OF DEBUG. + + // the "logical" number of rows. think of rows per slice + enum { ROWS = Cta_tile::M }; + + // "physical" rows + enum { TOTAL_ROWS = ROWS * NUM_SLICES }; + + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + enum { ROWS_PER_LOOP_PER_SLICE = Mma_tile::M_PER_MMA_PER_CTA }; + + enum { ROWS_PER_LOOP = Mma_tile::M_PER_MMA_PER_CTA * NUM_SLICES }; + + // DEBUG. + static_assert(ROWS_PER_LOOP == 16 * Cta_tile::WARPS_M * NUM_SLICES, ""); + + // END OF DEBUG. + + // The number of outer loop for the stores. + enum { LOOPS = TOTAL_ROWS / ROWS_PER_LOOP }; + + // Make sure the math is correct. + static_assert(LOOPS == (int)Mma_tile::MMAS_M, ""); + + // The number of "rows" stored per STG -- for it to be the number of rows per MMA instruction. + enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of STGs needed to store a chunk of the Q matrix. + enum { STGS_PER_LOOP = fmha::Div_up::VALUE }; + + enum { STGS_PER_SLICE = STGS_PER_LOOP / NUM_SLICES }; + + // DEBUG. + static_assert((Cta_tile::WARPS_M == 1 && STGS_PER_SLICE == 1) || + (Cta_tile::WARPS_M == 2 && STGS_PER_SLICE == 2), + ""); + + // END OF DEBUG. + + // Ctor. + template + inline __device__ Imma_gmem_tile_o_interleaved(Params const& params, Block_info const& block_info, + int tidx, int cta_row_offset = 0) + : params_o_stride_in_bytes_(params.o_stride_in_bytes), + actual_seqlen_(block_info.actual_seqlen - cta_row_offset), + params_scale_bmm2_(params.scale_bmm2), + params_enable_i2f_trick_(params.enable_i2f_trick), + o_ptr_(reinterpret_cast(params.o_ptr)), + total_(params.o_stride_in_bytes) { + int bidh = block_info.bidh; + int sum_s = block_info.sum_s; + + row_ = tidx / THREADS_PER_ROW; + int col = tidx % THREADS_PER_ROW; + + // h is N + // d is H + // want to save as: h x (d/32) x total x 32 (think 3 x h x (d/32) x b x s x 32) + + int block_offset = bidh * NUM_SLICES * total_ + sum_s; // bidh * GROUPS * B * S + b * S + int row_offset = (block_offset + cta_row_offset) * BYTES_PER_ROW; + + o_ptr_ += row_offset + col * BYTES_PER_STG; + } + + // Load data from global memory. + inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) { + static_assert(CTAS_PER_HEAD == 1, "Not implemented"); + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], int mi) { + int rows_so_far = mi * STGS_PER_LOOP * ROWS_PER_STG; + int rows_so_far_per_slice = rows_so_far / 2; + + // The scale. + float const& scale = reinterpret_cast(params_scale_bmm2_); + +// Iterate over the different STGs. 
+#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + // if(ii == 1) return; + // decompose the iteration into slice + int slice = ii / STGS_PER_SLICE; + int si = ii % STGS_PER_SLICE; + // dbg 256 + // assert(STGS_PER_SLICE == 1); + // assert(STGS_PER_LOOP == 2); + // assert(slice == ii); + // the number of rows one CTA-wide STG writes + static_assert(ROWS_PER_STG == 16, ""); // only holds for 4 warps/128 threads + int row_in_slice = row_ + si * ROWS_PER_STG + rows_so_far_per_slice; + + // we cannot return early, because the second half of iterates are + // responsible for the bottom slice + if (row_in_slice >= min(actual_seqlen_, ROWS)) { + continue; + } + + int offset = (slice * total_ + row_in_slice) * BYTES_PER_ROW; + + // The accumulators are in int32_t. + int4 const& val = reinterpret_cast(src[ii]); + + // if(threadIdx.x == 96){ + // printf("mi=%d ii=%d S=%d si=%d sofar=%d row=%d as=%d\n", mi, ii, slice, si, + // rows_so_far_per_slice, row_in_slice, actual_seqlen_) ; + // } + + uint32_t dst = quantize(val, scale, params_enable_i2f_trick_); + // Store the result. + fmha::stg(o_ptr_ + offset, dst); + } + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], uint4 const (&old)[STGS_PER_LOOP], + int mi) { + static_assert(CTAS_PER_HEAD == 1, "Not implemented"); + } + + // Move the pointer. + inline __device__ void move() { + o_ptr_ += (int64_t)ROWS * BYTES_PER_ROW; + actual_seqlen_ -= ROWS; + } + + // The stride between rows for the QKV matrice. + int64_t const params_o_stride_in_bytes_; + // The scaling factor to convert to int8. + uint32_t const params_scale_bmm2_; + // Do we enable the i2f trick? + bool const params_enable_i2f_trick_; + // The pointer. + char* o_ptr_; + int row_; + int actual_seqlen_; + int total_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace v2 +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/gmem_tile_ps.h b/csrc/fmha_v2/fmha/gmem_tile_ps.h new file mode 100644 index 0000000000..de150ff293 --- /dev/null +++ b/csrc/fmha_v2/fmha/gmem_tile_ps.h @@ -0,0 +1,837 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store_accumulator {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store_accumulator { + // The fragment. + using Acc = fmha::Fragment_accumulator; + + // Store. 
+ inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Acc const& acc, + uint32_t scale) { + uint32_t acc_0 = fmha::hmul2(acc.reg(0), scale); + uint32_t acc_1 = fmha::hmul2(acc.reg(1), scale); + uint32_t acc_2 = fmha::hmul2(acc.reg(2), scale); + uint32_t acc_3 = fmha::hmul2(acc.reg(3), scale); + + fmha::stg(ptr + 0 * step_m + 0 * step_n, acc_0); + fmha::stg(ptr + 1 * step_m + 0 * step_n, acc_1); + fmha::stg(ptr + 0 * step_m + 1 * step_n, acc_2); + fmha::stg(ptr + 1 * step_m + 1 * step_n, acc_3); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store_accumulator { + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // Store. + inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t) { + int32_t tmp_0 = acc.elt(0); + int32_t tmp_1 = acc.elt(1); + int32_t tmp_2 = acc.elt(2); + int32_t tmp_3 = acc.elt(3); + int32_t tmp_4 = acc.elt(4); + int32_t tmp_5 = acc.elt(5); + int32_t tmp_6 = acc.elt(6); + int32_t tmp_7 = acc.elt(7); + +#if defined(USE_I2F_EMULATION_TRICK) + tmp_0 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_1 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_2 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_3 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_4 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_5 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_6 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_7 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); +#endif + + uint32_t acc_0 = reinterpret_cast(tmp_0); + uint32_t acc_1 = reinterpret_cast(tmp_1); + uint32_t acc_2 = reinterpret_cast(tmp_2); + uint32_t acc_3 = reinterpret_cast(tmp_3); + uint32_t acc_4 = reinterpret_cast(tmp_4); + uint32_t acc_5 = reinterpret_cast(tmp_5); + uint32_t acc_6 = reinterpret_cast(tmp_6); + uint32_t acc_7 = reinterpret_cast(tmp_7); + + fmha::stg(ptr + 0 * step_m + 0 * step_n, make_uint2(acc_0, acc_1)); + fmha::stg(ptr + 1 * step_m + 0 * step_n, make_uint2(acc_4, acc_5)); + fmha::stg(ptr + 0 * step_m + 1 * step_n, make_uint2(acc_2, acc_3)); + fmha::stg(ptr + 1 * step_m + 1 * step_n, make_uint2(acc_6, acc_7)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Store_accumulator { + // The instruction traits. + using Traits = Ampere_hmma_fp32_traits; + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // Store. 
+ inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t scale) { + float const scalef = reinterpret_cast(scale); + + float const tmp_0 = acc.elt(0) * scalef; + float const tmp_1 = acc.elt(1) * scalef; + float const tmp_2 = acc.elt(2) * scalef; + float const tmp_3 = acc.elt(3) * scalef; + float const tmp_4 = acc.elt(4) * scalef; + float const tmp_5 = acc.elt(5) * scalef; + float const tmp_6 = acc.elt(6) * scalef; + float const tmp_7 = acc.elt(7) * scalef; + + uint32_t acc_0 = reinterpret_cast(tmp_0); + uint32_t acc_1 = reinterpret_cast(tmp_1); + uint32_t acc_2 = reinterpret_cast(tmp_2); + uint32_t acc_3 = reinterpret_cast(tmp_3); + uint32_t acc_4 = reinterpret_cast(tmp_4); + uint32_t acc_5 = reinterpret_cast(tmp_5); + uint32_t acc_6 = reinterpret_cast(tmp_6); + uint32_t acc_7 = reinterpret_cast(tmp_7); + + fmha::stg(ptr + 0 * step_m + 0 * step_n, make_uint2(acc_0, acc_1)); + fmha::stg(ptr + 1 * step_m + 0 * step_n, make_uint2(acc_2, acc_3)); + fmha::stg(ptr + 0 * step_m + 1 * step_n, make_uint2(acc_4, acc_5)); + fmha::stg(ptr + 1 * step_m + 1 * step_n, make_uint2(acc_6, acc_7)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Store_accumulator { + // The instruction traits. + using Traits = Ampere_hmma_bf16_traits; + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // Store. + inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t scale) { + float const scalef = reinterpret_cast(scale); + + float const tmp_0 = acc.elt(0) * scalef; + float const tmp_1 = acc.elt(1) * scalef; + float const tmp_2 = acc.elt(2) * scalef; + float const tmp_3 = acc.elt(3) * scalef; + float const tmp_4 = acc.elt(4) * scalef; + float const tmp_5 = acc.elt(5) * scalef; + float const tmp_6 = acc.elt(6) * scalef; + float const tmp_7 = acc.elt(7) * scalef; + + uint32_t acc_0 = reinterpret_cast(tmp_0); + uint32_t acc_1 = reinterpret_cast(tmp_1); + uint32_t acc_2 = reinterpret_cast(tmp_2); + uint32_t acc_3 = reinterpret_cast(tmp_3); + uint32_t acc_4 = reinterpret_cast(tmp_4); + uint32_t acc_5 = reinterpret_cast(tmp_5); + uint32_t acc_6 = reinterpret_cast(tmp_6); + uint32_t acc_7 = reinterpret_cast(tmp_7); + + fmha::stg(ptr + 0 * step_m + 0 * step_n, make_uint2(acc_0, acc_1)); + fmha::stg(ptr + 1 * step_m + 0 * step_n, make_uint2(acc_2, acc_3)); + fmha::stg(ptr + 0 * step_m + 1 * step_n, make_uint2(acc_4, acc_5)); + fmha::stg(ptr + 1 * step_m + 1 * step_n, make_uint2(acc_6, acc_7)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t pack_char2(uint32_t a, uint32_t b) { + uint32_t dst; + asm volatile("prmt.b32 %0, %1, %2, 0x0040;\n" : "=r"(dst) : "r"(a), "r"(b)); + return reinterpret_cast(dst); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store_accumulator { + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // Store. + inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t) { + // Pack pairs of values. + uint16_t tmp_00 = pack_char2(acc.reg(0), acc.reg(1)); + uint16_t tmp_01 = pack_char2(acc.reg(2), acc.reg(3)); + uint16_t tmp_10 = pack_char2(acc.reg(4), acc.reg(5)); + uint16_t tmp_11 = pack_char2(acc.reg(6), acc.reg(7)); + + // Store to memory. 
+ fmha::stg(ptr + 0 * step_m + 0 * step_n, tmp_00); + fmha::stg(ptr + 1 * step_m + 0 * step_n, tmp_10); + fmha::stg(ptr + 0 * step_m + 1 * step_n, tmp_01); + fmha::stg(ptr + 1 * step_m + 1 * step_n, tmp_11); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Store_accumulator { + // The traits. + using Traits = fmha::Ampere_imma_int8_int32_traits; + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // Store. + inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t) { + int32_t tmp_0 = acc.elt(0); + int32_t tmp_1 = acc.elt(1); + int32_t tmp_2 = acc.elt(2); + int32_t tmp_3 = acc.elt(3); + int32_t tmp_4 = acc.elt(4); + int32_t tmp_5 = acc.elt(5); + int32_t tmp_6 = acc.elt(6); + int32_t tmp_7 = acc.elt(7); + +#if defined(USE_I2F_EMULATION_TRICK) + tmp_0 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_1 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_2 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_3 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_4 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_5 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_6 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_7 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); +#endif + + uint32_t acc_0 = reinterpret_cast(tmp_0); + uint32_t acc_1 = reinterpret_cast(tmp_1); + uint32_t acc_2 = reinterpret_cast(tmp_2); + uint32_t acc_3 = reinterpret_cast(tmp_3); + uint32_t acc_4 = reinterpret_cast(tmp_4); + uint32_t acc_5 = reinterpret_cast(tmp_5); + uint32_t acc_6 = reinterpret_cast(tmp_6); + uint32_t acc_7 = reinterpret_cast(tmp_7); + + fmha::stg(ptr + 0 * step_m + 0 * step_n, make_uint2(acc_0, acc_1)); + fmha::stg(ptr + 1 * step_m + 0 * step_n, make_uint2(acc_2, acc_3)); + fmha::stg(ptr + 0 * step_m + 1 * step_n, make_uint2(acc_4, acc_5)); + fmha::stg(ptr + 1 * step_m + 1 * step_n, make_uint2(acc_6, acc_7)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Store_accumulator { + // The traits. + using Traits = fmha::Ampere_imma_int8_int32_traits; + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // Store. + inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t) { + // Pack pairs of values. + uint16_t tmp_00 = pack_char2(acc.reg(0), acc.reg(1)); + uint16_t tmp_01 = pack_char2(acc.reg(4), acc.reg(5)); + uint16_t tmp_10 = pack_char2(acc.reg(2), acc.reg(3)); + uint16_t tmp_11 = pack_char2(acc.reg(6), acc.reg(7)); + + // Store to memory. + fmha::stg(ptr + 0 * step_m + 0 * step_n, tmp_00); + fmha::stg(ptr + 1 * step_m + 0 * step_n, tmp_10); + fmha::stg(ptr + 0 * step_m + 1 * step_n, tmp_01); + fmha::stg(ptr + 1 * step_m + 1 * step_n, tmp_11); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store_accumulator< + fmha::Hopper_hgmma_fp16_traits, 16> { + // The traits. + using Traits = fmha::Hopper_hgmma_fp16_traits; + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // The number of rows accessed by each thread. + enum { ROWS_PER_THREAD = GMMA_M / 8 / 4 }; + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLUMNS_PER_THREAD = GMMA_N / 4 / 2 }; + + // The number of accumulator held by each thread, per HGMMA instruction. 
+ enum { ELEMENT_PER_THREAD = ROWS_PER_THREAD * COLUMNS_PER_THREAD }; + + // Store. + inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t scale) { +#pragma unroll + for (int col_idx = 0; col_idx < COLUMNS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + uint32_t acc_0 = fmha::hmul2(acc.reg(col_idx * ROWS_PER_THREAD + row_idx), scale); + // float one = 1.f; + // if(col_idx > 2){ + // acc_0 = float2_to_half2(one, one); + // } + int64_t offset = (int64_t)row_idx * step_m + (int64_t)col_idx * step_n; + fmha::stg(ptr + offset, acc_0); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store_accumulator< + fmha::Hopper_qgmma_fp8_fp32_traits, + 32> { + // The traits. + using Traits = fmha::Hopper_qgmma_fp8_fp32_traits; + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // The number of rows accessed by each thread. + enum { ROWS_PER_THREAD = GMMA_M / 8 / 4 }; + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLUMNS_PER_THREAD = GMMA_N / 8 }; + + // The number of accumulator held by each thread, per HGMMA instruction. + enum { ELEMENT_PER_THREAD = ROWS_PER_THREAD * COLUMNS_PER_THREAD }; + + // Store. + inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t scale) { + float const scalef = reinterpret_cast(scale); +#pragma unroll + for (int col_idx = 0; col_idx < COLUMNS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + float const acc_0 = acc.elt((col_idx * ROWS_PER_THREAD + row_idx) * 2 + 0) * scalef; + float const acc_1 = acc.elt((col_idx * ROWS_PER_THREAD + row_idx) * 2 + 1) * scalef; + uint2 acc_; + acc_.x = reinterpret_cast(acc_0); + acc_.y = reinterpret_cast(acc_1); + int64_t offset = (int64_t)row_idx * step_m + (int64_t)col_idx * step_n; + fmha::stg(ptr + offset, acc_); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store_accumulator< + fmha::Hopper_igmma_int8_int32_traits, 32> { + // The traits. + using Traits = fmha::Hopper_igmma_int8_int32_traits; + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // The number of rows accessed by each thread. + enum { ROWS_PER_THREAD = GMMA_M / 8 / 4 }; + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLUMNS_PER_THREAD = GMMA_N / 8 }; + + // The number of accumulator held by each thread, per HGMMA instruction. + enum { ELEMENT_PER_THREAD = ROWS_PER_THREAD * COLUMNS_PER_THREAD }; + + // Store. 
+ inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t scale) { +#pragma unroll + for (int col_idx = 0; col_idx < COLUMNS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + int32_t const acc_0 = acc.elt((col_idx * ROWS_PER_THREAD + row_idx) * 2 + 0); + int32_t const acc_1 = acc.elt((col_idx * ROWS_PER_THREAD + row_idx) * 2 + 1); + uint2 acc_; + acc_.x = reinterpret_cast(acc_0); + acc_.y = reinterpret_cast(acc_1); + int64_t offset = (int64_t)row_idx * step_m + (int64_t)col_idx * step_n; + fmha::stg(ptr + offset, acc_); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static __device__ inline uint16_t pack_e4m3x2(float const x, float const y) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + uint16_t storage; + asm volatile("{cvt.rn.satfinite.e4m3x2.f32 %0, %2, %1;}\n" : "=h"(storage) : "f"(x), "f"(y)); + return storage; +#else + assert(false); + return 0; +#endif +} + +static __device__ inline uint16_t pack_e5m2x2(float const x, float const y) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + uint16_t storage; + asm volatile("{cvt.rn.satfinite.e5m2x2.f32 %0, %2, %1;}\n" : "=h"(storage) : "f"(x), "f"(y)); + return storage; +#else + assert(false); + return 0; +#endif +} + +template +__device__ inline uint16_t pack_fp8x2(float const x, float const y); + +template <> +__device__ inline uint16_t pack_fp8x2(float const x, float const y) { + return pack_e4m3x2(x, y); +} + +template <> +__device__ inline uint16_t pack_fp8x2(float const x, float const y) { + return pack_e5m2x2(x, y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store_accumulator< + fmha::Hopper_qgmma_fp8_fp32_traits, + 8> { + // The traits. + using Traits = fmha::Hopper_qgmma_fp8_fp32_traits; + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // The number of rows accessed by each thread. + enum { ROWS_PER_THREAD = GMMA_M / 8 / 4 }; + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLUMNS_PER_THREAD = GMMA_N / 8 }; + + // The number of accumulator held by each thread, per HGMMA instruction. + enum { ELEMENT_PER_THREAD = ROWS_PER_THREAD * COLUMNS_PER_THREAD }; + + // Store. + inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t) { +#pragma unroll + for (int col_idx = 0; col_idx < COLUMNS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + float const acc_0 = acc.elt((col_idx * ROWS_PER_THREAD + row_idx) * 2 + 0); + float const acc_1 = acc.elt((col_idx * ROWS_PER_THREAD + row_idx) * 2 + 1); + // uint16_t acc_ = pack_e4m3x2(acc_0, acc_1); + uint16_t acc_ = pack_fp8x2(acc_0, acc_1); + int64_t offset = (int64_t)row_idx * step_m + (int64_t)col_idx * step_n; + fmha::stg(ptr + offset, acc_); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store_accumulator< + fmha::Hopper_igmma_int8_int32_traits, 8> { + // The traits. + using Traits = fmha::Hopper_igmma_int8_int32_traits; + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // The number of rows accessed by each thread. + enum { ROWS_PER_THREAD = GMMA_M / 8 / 4 }; + + // The number of columns access by each thread. 
+ // Note there are 2 elements per reg. + enum { COLUMNS_PER_THREAD = GMMA_N / 8 }; + + // The number of accumulator held by each thread, per HGMMA instruction. + enum { ELEMENT_PER_THREAD = ROWS_PER_THREAD * COLUMNS_PER_THREAD }; + + // Store. + inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t) { +#pragma unroll + for (int col_idx = 0; col_idx < COLUMNS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + uint32_t const acc_0 = acc.reg((col_idx * ROWS_PER_THREAD + row_idx) * 2 + 0); + uint32_t const acc_1 = acc.reg((col_idx * ROWS_PER_THREAD + row_idx) * 2 + 1); + uint16_t acc_ = pack_char2(acc_0, acc_1); + int64_t offset = (int64_t)row_idx * step_m + (int64_t)col_idx * step_n; + fmha::stg(ptr + offset, acc_); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_ps { + // The associated MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of elements per STG. + enum { ELEMENTS_PER_STG = 2 }; + + // The size in bytes of each element. + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // The size of each STG. + enum { BYTES_PER_STG = ELEMENTS_PER_STG * BYTES_PER_ELEMENT }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // // DEBUG. + // static_assert(BYTES_PER_ROW == 384 || BYTES_PER_ROW == 768 || BYTES_PER_ROW == 1536, ""); + // // END OF DEBUG. + + // Ctor. + inline __device__ Gmem_tile_ps(void* ptr, int64_t const params_stride_in_bytes, + uint32_t const params_scale, int tidx, int cta_row_offset = 0) + : params_stride_in_bytes_(params_stride_in_bytes), + params_scale_(params_scale), + ptr_(reinterpret_cast(ptr)) { + // For storing P and S, we do not take into account variable sequence length. + + // The block index for the batch. + int const bidb = blockIdx.y; + // The block index for the head. + int const bidh = blockIdx.x; + // The block index. + int bidx = bidb * gridDim.x + bidh; + + // Decompose the position of the thread into warp/lane. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // Compute the position in the sequence (within the CTA for the moment). + int row = warp % Cta_tile::WARPS_M * Mma_tile::M_PER_MMA + lane / 4 + cta_row_offset; + // Compute the position of the thread in the row. + int col = warp / Cta_tile::WARPS_M * Mma_tile::N_PER_MMA + lane % 4 * ELEMENTS_PER_STG; + + // The offset of the 1st row written by the thread. We store the P matrix interleaved. + int64_t row_offset = (int64_t)row * params_stride_in_bytes_ + bidx * BYTES_PER_ROW; + // Finalize the pointer. + ptr_ += row_offset + col * BYTES_PER_ELEMENT; + } + + // Store data to memory. + template + inline __device__ void store(Accumulators const (&acc)[M][N]) { + // A thread holds packet of 2 elements. In 2x2 tile per MMA. + int64_t const step_m = 8 * params_stride_in_bytes_; + int64_t const step_n = 8 * BYTES_PER_ELEMENT; + +// Store the different accumulators. +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { + int64_t offset = (int64_t)mi * Mma_tile::M_PER_MMA_PER_CTA * params_stride_in_bytes_ + + ni * Mma_tile::N_PER_MMA_PER_CTA * BYTES_PER_ELEMENT; + Store_accumulator delegate; + delegate.store(ptr_ + offset, step_m, step_n, acc[mi][ni], params_scale_); + } + } + } + + // Move to the next location. 
+ inline __device__ void move() { ptr_ += (int64_t)Cta_tile::M * params_stride_in_bytes_; } + + inline __device__ void move_n() { ptr_ += (int64_t)Cta_tile::N * BYTES_PER_ELEMENT; } + + // The stride between rows for the QKV matrice. + int64_t const params_stride_in_bytes_; + // The scale to apply before storing the element. + uint32_t const params_scale_; + // The pointer. + char* ptr_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_ps { + // The traits class. + using Traits = Volta_hmma_fp16_traits; + // The associated MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of elements per STG. + enum { ELEMENTS_PER_STG = 4 }; + + // The size in bytes of each element. + enum { BYTES_PER_ELEMENT = 2 }; + + // The size of each STG. + enum { BYTES_PER_STG = ELEMENTS_PER_STG * BYTES_PER_ELEMENT }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // Ctor. + inline __device__ Gmem_tile_ps(void* ptr, int64_t const params_stride_in_bytes, + uint32_t const params_scale, int tidx, int cta_row_offset = 0) + : params_stride_in_bytes_(params_stride_in_bytes), + params_scale_(params_scale), + ptr_(reinterpret_cast(ptr)) { + // For storing P and S, we do not take into account variable sequence lengths. + + // The block index for the batch. + int const bidb = blockIdx.y; + // The block index for the head. + int const bidh = blockIdx.x; + // The block index. + int bidx = bidb * gridDim.x + bidh; + + // Decompose the position of the thread into warp/lane. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // DEBUG. + static_assert(Mma_tile::M_PER_MMA == 16 && Mma_tile::N_PER_MMA == 16, ""); + // END OF DEBUG. + + // The position of the warp. + int warp_row = warp % Cta_tile::WARPS_M * Mma_tile::M_PER_MMA; + int warp_col = warp / Cta_tile::WARPS_M * Mma_tile::N_PER_MMA; + + // Compute the position of the thread (within the CTA for the moment). + int row = warp_row + (lane & 0x10) / 2 + (lane & 0x07); + int col = warp_col + (lane & 0x08) / 2; + + // // DEBUG. + // printf("tidx=%3d row=%3d col=%3d\n", tidx, row, col); + // // END OF DEBUG. + + // The offset of the 1st row written by the thread. We store the P matrix interleaved. + int64_t row_offset = + (int64_t)row * params_stride_in_bytes_ + bidx * BYTES_PER_ROW + cta_row_offset; + + // Finalize the pointer. + ptr_ += row_offset + col * BYTES_PER_ELEMENT; + } + + // Store data to memory. + template + inline __device__ void store(Accumulators const (&acc)[M][N]) { +// Store the different accumulators. +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { + // Scale the accumulators. + uint32_t acc_0 = fmha::hmul2(acc[mi][ni].reg(0), params_scale_); + uint32_t acc_1 = fmha::hmul2(acc[mi][ni].reg(1), params_scale_); + uint32_t acc_2 = fmha::hmul2(acc[mi][ni].reg(2), params_scale_); + uint32_t acc_3 = fmha::hmul2(acc[mi][ni].reg(3), params_scale_); + + // The offsets. + int row = mi * Mma_tile::M_PER_MMA_PER_CTA; + int col = ni * Mma_tile::N_PER_MMA_PER_CTA * BYTES_PER_ELEMENT; + + // The offset in bytes. + int64_t offset = (int64_t)row * params_stride_in_bytes_ + col; + + // In one MMA, 16 FP16s are interleaved between threads i and i+8 in groups of 4. 
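+        // Each thread therefore issues two 8-byte STGs: elements [col, col+3] and
+        // [col+8, col+11] of the row.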
+ fmha::stg(&ptr_[offset + 0 * BYTES_PER_ELEMENT], make_uint2(acc_0, acc_1)); + fmha::stg(&ptr_[offset + 8 * BYTES_PER_ELEMENT], make_uint2(acc_2, acc_3)); + } + } + } + + // Move to the next location. + inline __device__ void move() { ptr_ += (int64_t)Cta_tile::M * params_stride_in_bytes_; } + + // The stride between rows for the QKV matrice. + int64_t const params_stride_in_bytes_; + // The scale to apply before storing the element. + uint32_t const params_scale_; + // The pointer. + char* ptr_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_p : public Gmem_tile_ps { + // The base class. + using Base = Gmem_tile_ps; + + // Ctor. + inline __device__ Gmem_tile_p(void* ptr, int64_t const params_stride_in_bytes, + uint32_t const params_scale, int tidx, int cta_row_offset = 0) + : Base(ptr, params_stride_in_bytes, params_scale, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Not super proud of this. Need to refactor. +template +struct Gmem_tile_ps_hopper { + // The associated MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of elements per STG. + enum { ELEMENTS_PER_STG = 2 }; + + // The size in bytes of each element. + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // The size of each STG. + enum { BYTES_PER_STG = ELEMENTS_PER_STG * BYTES_PER_ELEMENT }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // Ctor. + inline __device__ Gmem_tile_ps_hopper(void* ptr, int64_t const params_stride_in_bytes, + int64_t const bytes_per_row, uint32_t const params_scale, + int tidx) + : params_stride_in_bytes_(params_stride_in_bytes), + params_scale_(params_scale), + ptr_(reinterpret_cast(ptr)) { + // For storing P and S, we do not take into account variable sequence length. + + // The block index for the batch. + int const bidb = blockIdx.y; + // The block index for the head. + int const bidh = blockIdx.x; + // The block index. + int bidx = bidb * gridDim.x + bidh; + + // Decompose the position of the thread into warp/lane. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + int warpgroup_idx = warp / 4; + int warp_idx_within_warpgroup = warp % 4; + + // Compute the position in the sequence (within the CTA for the moment). + int row = warp_idx_within_warpgroup * (Mma_tile::M_PER_MMA / 4) + lane / 4; + // Compute the position of the thread in the row. + int col = warpgroup_idx * Mma_tile::N_PER_MMA + lane % 4 * ELEMENTS_PER_STG; + + // The offset of the 1st row written by the thread. We store the P matrix interleaved. + int64_t row_offset = (int64_t)row * params_stride_in_bytes_ + bidx * bytes_per_row; + // Finalize the pointer. + ptr_ += row_offset + col * BYTES_PER_ELEMENT; + } + + // Ctor. + inline __device__ Gmem_tile_ps_hopper(void* ptr, int64_t const params_stride_in_bytes, + uint32_t const params_scale, int tidx) + : Gmem_tile_ps_hopper(ptr, params_stride_in_bytes, BYTES_PER_ROW, params_scale, tidx) {} + + // Store data to memory. + template + inline __device__ void store(Accumulators const (&acc)[M][N]) { + // A thread holds packet of 2 elements. In 2x2 tile per MMA. + // Need to figure out if we need this for hopper. + int64_t const step_m = 8 * (this->params_stride_in_bytes_); + int64_t const step_n = 8 * BYTES_PER_ELEMENT; + +// Store the different accumulators. 
+#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { + int64_t offset = + (int64_t)mi * Mma_tile::M_PER_MMA_PER_CTA * (this->params_stride_in_bytes_) + + ni * Mma_tile::N_PER_MMA_PER_CTA * BYTES_PER_ELEMENT; + + Store_accumulator delegate; + delegate.store(this->ptr_ + offset, step_m, step_n, acc[mi][ni], this->params_scale_); + } + } + } + + // Move to the next location. + inline __device__ void move() { ptr_ += (int64_t)Cta_tile::M * params_stride_in_bytes_; } + + // The stride between rows for the QKV matrice. + int64_t const params_stride_in_bytes_; + // The scale to apply before storing the element. + uint32_t const params_scale_; + // The pointer. + char* ptr_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_s : public Gmem_tile_ps { + // The base class. + using Base = Gmem_tile_ps; + + // Ctor. + inline __device__ Gmem_tile_s(void* ptr, int64_t const params_stride_in_bytes, + uint32_t const params_scale, int tidx) + : Base(ptr, params_stride_in_bytes, params_scale, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_s + : public Gmem_tile_ps { + // The base class. + using Base = Gmem_tile_ps; + + // Ctor. + inline __device__ Gmem_tile_s(void* ptr, int64_t const params_stride_in_bytes, + uint32_t const params_scale, int tidx, int cta_row_offset = 0) + : Base(ptr, params_stride_in_bytes, + float_to_half2(reinterpret_cast(params_scale)), tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/gmem_tile_qkv.h b/csrc/fmha_v2/fmha/gmem_tile_qkv.h new file mode 100644 index 0000000000..0c0af5c8e4 --- /dev/null +++ b/csrc/fmha_v2/fmha/gmem_tile_qkv.h @@ -0,0 +1,167 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +namespace fmha { +namespace v1 { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The number of bits per element. + int BITS_PER_ELEMENT, + // The number of rows of Q, K or V loaded by this tile. + int ROWS_, + // The number of columns. + int COLS, + // The number of valid columns + int VALID_COLS, + // Do we use LDGSTS? + bool USE_LDGSTS_, + // Are attention heads interleaved? + bool HEADS_INTERLEAVED, + // Number of matrices + int NUM_MATS = 3, + // Is sliding window attention used ? + bool SLIDING_WINDOW_ATTENTION = false> +struct Gmem_tile_qkv { + // The size of each LDG. + enum { BYTES_PER_LDG = 16 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 }; + + // The number of threads to load a "row" of the matrix. 
+ enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; + + // The valid size of a row in bytes. + enum { VALID_BYTES_PER_ROW = VALID_COLS * BITS_PER_ELEMENT / 8 }; + + // The valid number of threads to load a "row" of the matrix. + enum { VALID_THREADS_PER_ROW = VALID_BYTES_PER_ROW / BYTES_PER_LDG }; + + // The number of "rows" loaded per LDG. + enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of rows. + enum { ROWS = ROWS_ }; + + // The number of LDGs needed to load a chunk of the Q matrix. + enum { LDGS = fmha::Div_up::VALUE }; + + // The number of predicate registers. + enum { PRED_REGS = fmha::Compute_number_of_pred_regs::VALUE }; + + // Make sure we use a single register to store predicates. + static_assert(PRED_REGS == 1, ""); + + // We do not use LDGSTS (for the moment). + enum { USE_LDGSTS = USE_LDGSTS_ }; + + // Ctor. + template + inline __device__ Gmem_tile_qkv(Params const& params, int qkv_offset, Block_info const& binfo, + int tidx, int cta_row_offset = 0) + + // in PACKED_QKV, q_stride = k_stride = v_stride + : params_qkv_stride_in_bytes_(params.q_stride_in_bytes), + qkv_ptr_(reinterpret_cast(params.qkv_ptr)) { + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // Prepare predicates. + uint32_t preds[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + preds[ii] = row + ii * ROWS_PER_LDG < ROWS; + } + + // Pack the predicates. + preds_[0] = fmha::pack_predicates(preds); + + // The row offset in the batched GEMM. For each seq element, we store QKV in that order. + int64_t row_offset = (int64_t)(row + cta_row_offset) * params_qkv_stride_in_bytes_; + // Add the block index. + int idx; + if (HEADS_INTERLEAVED) { + idx = binfo.bidx * NUM_MATS + qkv_offset; + } else { + idx = (params.b * params.s * NUM_MATS + qkv_offset) * params.h + binfo.bidh; + } + // Assemble the final pointer. + qkv_ptr_ += row_offset + idx * VALID_BYTES_PER_ROW + col * BYTES_PER_LDG; + + // active threads + is_active_ = col < VALID_THREADS_PER_ROW; + } + + // Store data to shared memory. + template + inline __device__ void commit(Smem_tile& smem_tile) { + if (!USE_LDGSTS) { + smem_tile.store(fetch_); + } + } + + // Load data from memory. + template + inline __device__ void load(Smem_tile& smem_tile) { + void const* ptrs[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; + } + if (USE_LDGSTS) { + smem_tile.store(ptrs, preds_); + } else { + fmha::ldg(fetch_, ptrs, preds_); + } + } + + // Load data from global memory, shared mem is not needed + inline __device__ void load() { + void const* ptrs[LDGS]; + if (is_active_) { +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; + } + fmha::ldg(fetch_, ptrs, preds_); + } + } + + // Move the pointer to the next location. + inline __device__ void move() { qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_; } + + // The stride between rows for the QKV matrice. + int64_t const params_qkv_stride_in_bytes_; + // The pointer. + char const* qkv_ptr_; + // The register to store predicates. + uint32_t preds_[PRED_REGS]; + // The fetch registers. 
+ uint4 fetch_[LDGS]; + // The active LDG threads + bool is_active_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace v1 +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h b/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h new file mode 100644 index 0000000000..00797d0a01 --- /dev/null +++ b/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h @@ -0,0 +1,1307 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include +#include + +namespace fmha { +namespace v2 { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Ldgsts_helper { + template + static inline __device__ void load(This* this_, Smem_tile& smem_tile, void const* (&ptrs)[LDGS], + uint32_t (&preds)[LDGS]) { + fmha::pack_predicates(this_->preds_, preds); + smem_tile.store(ptrs, this_->preds_); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Ldgsts_helper<0> { + template + static inline __device__ void load(This* this_, Smem_tile& smem_tile, void const* (&ptrs)[LDGS], + uint32_t (&preds)[LDGS]) { +#if 0 + fmha::pack_predicates(this_->preds_, preds); + fmha::ldg(this_->fetch_, ptrs, this_->preds_); +#else +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + this_->fetch_[ii] = make_uint4(0u, 0u, 0u, 0u); + } + // not packing predicates removes restrictions (e.g. FP16 384, 4 warps) + Ldg_functor fct(this_->fetch_, ptrs); +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + fct.ldgsts(ii, preds[ii]); + } +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The number of bits per element. + int BITS_PER_ELEMENT_, + // The number of rows of Q, K or V loaded by this tile. + int ROWS_, + // The number of columns (padded, e.g 64). + int COLS, + // The actual number of columns (unpadded, e.g 40) + int VALID_COLS_, + // Do we use LDGSTS? + bool USE_LDGSTS_, + // Are attention heads interleaved? + bool HEADS_INTERLEAVED, + // The number of matrices + int NUM_MATS = 3, + // Is sliding window attention used ? + bool SLIDING_WINDOW_ATTENTION = false> +struct Gmem_tile_qkv { + // The size of each LDG. + enum { BYTES_PER_LDG = 16 }; + + // The number of bits/bytes of element + enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ }; + + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT_ / 8 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 }; + + // The number of threads to load a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; + + // The valid size of a row in bytes (without paddings). + enum { VALID_COLS = VALID_COLS_ }; + + // The amount of bytes that are valid per row. 
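Ldgsts_helper above is a tag-dispatch shim: specializing on the USE_LDGSTS flag picks the copy mechanism at compile time while the tile's load() keeps a single body. A simplified, host-side analogue of the same pattern (the Copy_helper name and the float payload are inventions for illustration only):

```cpp
#include <cstdio>

// The bool parameter selects the code path at compile time, mirroring how the primary
// Ldgsts_helper template handles the LDGSTS path and the <0> specialization the LDG path.
template <bool USE_LDGSTS>
struct Copy_helper {
  static void load(float* dst, const float* src, int n) {
    // Stand-in for the cp.async (LDGSTS) path: copy lands directly in shared memory.
    std::printf("async path, %d elements\n", n);
    for (int i = 0; i < n; ++i) dst[i] = src[i];
  }
};

template <>
struct Copy_helper<false> {
  static void load(float* dst, const float* src, int n) {
    // Stand-in for the LDG path: data is staged through registers first.
    std::printf("register path, %d elements\n", n);
    for (int i = 0; i < n; ++i) dst[i] = src[i];
  }
};

int main() {
  float src[4] = {1, 2, 3, 4}, dst[4];
  Copy_helper<true>::load(dst, src, 4);   // resolved at compile time, no runtime branch
  Copy_helper<false>::load(dst, src, 4);
  return 0;
}
```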
+ enum { VALID_BYTES_PER_ROW = VALID_COLS * BITS_PER_ELEMENT / 8 }; + + // The number of "rows" loaded per LDG. + enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of rows. + enum { ROWS = ROWS_ }; + + // The number of LDGs needed to load a chunk of the Q matrix. + enum { LDGS = fmha::Div_up::VALUE }; + + // The number of predicate registers. + enum { PRED_REGS = fmha::Compute_number_of_pred_regs::VALUE }; + + // Is it Hopper? + enum { + IS_HOPPER = std::is_same::value == true + }; + + // Make sure we use a single register to store predicates. Do not throw for Hopper for now. + static_assert(!USE_LDGSTS_ || PRED_REGS == 1 || IS_HOPPER, ""); + + // We do not use LDGSTS (for the moment). + enum { USE_LDGSTS = USE_LDGSTS_ }; + + // Ctor for bert::Fused_multihead_attention_params_v2 class + template + inline __device__ Gmem_tile_qkv(bert::Fused_multihead_attention_params_v2 const& params, + int qkv_offset, Block_info const& binfo, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Gmem_tile_qkv(params.qkv_ptr, params.q_stride_in_bytes, params.d, params.dv, params.h, + qkv_offset, binfo, tidx, params.h_kv, cta_row_offset, + cta_col_offset_in_bytes) {} + + // Ctor for other param classes (such as Qkv_params in train_ops) + template + inline __device__ Gmem_tile_qkv(Params const& params, int qkv_offset, Block_info const& binfo, + int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Gmem_tile_qkv(params.qkv_ptr, params.q_stride_in_bytes, params.d, params.dv, params.h, + qkv_offset, binfo, tidx, cta_row_offset, cta_col_offset_in_bytes) {} + + // Ctor. + template + inline __device__ Gmem_tile_qkv(void* qkv_ptr, size_t qkv_stride_in_bytes, int d, int dv, + int num_heads, int qkv_offset, Block_info const& binfo, int tidx, + int num_kv_heads = 0, int cta_row_offset = 0, + int cta_col_offset_in_bytes = 0) + : params_qkv_stride_in_bytes_(qkv_stride_in_bytes), + actual_seqlen_(binfo.actual_seqlen), + qkv_ptr_(reinterpret_cast(qkv_ptr)) { + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // We must store the value to update the predicates in "load". + row_ = row; + // Do not load/store if the thread is in the padded area + col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_LDG; + + // The row offset in the batched GEMM. For each seq element, we store QKV in that order. + int64_t row_offset = (int64_t)(row + cta_row_offset) * params_qkv_stride_in_bytes_; + // Add the byte index. + int64_t idx; + + // Both MQA and GQA will use non HEADS_INTERLEAVED layout + if (num_kv_heads < num_heads) { + int const head_id = binfo.bidh; + int const kv_head_id = binfo.bidh / (num_heads / num_kv_heads); + // QKV layout [b, s, [q_hd, k_h'd, v_h'd]] + idx = binfo.sum_s * params_qkv_stride_in_bytes_; + if (qkv_offset == 0) { // Q tensor + idx += head_id * VALID_BYTES_PER_ROW; + } else if (qkv_offset == 1) { // K tensor + idx += (num_heads + kv_head_id) * VALID_BYTES_PER_ROW; + } else if (qkv_offset == 2) { // V tensor + /* When qkv_offset == 2, this is an instance of Gmem_tile_v defined in Kernel_traits: + using Gmem_tile_v = Gmem_tile_v_; + the 6th template argument is VALID_DV instead of VALID_D. 
+ Thus, here VALID_COLS equals VALID_DV, and + VALID_BYTES_PER_ROW equals VALID_DV * BYTES_PER_ELEMENT, + and `kv_head_id * dv * BYTES_PER_ELEMENT` can be optimized to + `kv_head_id * VALID_BYTES_PER_ROW`. */ + idx += + (num_heads + num_kv_heads) * d * BYTES_PER_ELEMENT + kv_head_id * VALID_BYTES_PER_ROW; + } + } else if (HEADS_INTERLEAVED) { + // [b, s, h, [q_d, k_d, v_d]] aka bsh3d + // bidx = sum_s * params.h + bidh; + idx = (binfo.bidx * (2 * d + dv) + qkv_offset * d) * BYTES_PER_ELEMENT; + } else { + // [b, s, [q_hd, k_hd, v_hd]] aka bs3hd + idx = binfo.sum_s * params_qkv_stride_in_bytes_ + + qkv_offset * num_heads * d * BYTES_PER_ELEMENT + binfo.bidh * VALID_BYTES_PER_ROW; + } + + // Assemble the final pointer. + qkv_ptr_ += row_offset + idx + col_in_bytes_; + + // Take the CTA offset to modify the sequence length. + actual_seqlen_ -= cta_row_offset; + + // Set the initial seq_len and qkv_offset in case of reinterating + actual_seqlen_init_ = actual_seqlen_; + qkv_ptr_init_ = qkv_ptr_; + } + + // Store data to shared memory. + template + inline __device__ void commit(Smem_tile& smem_tile) { + if (!USE_LDGSTS) { + smem_tile.store(fetch_); + } + } + + // Load data from memory. + template + inline __device__ void load(Smem_tile& smem_tile) { + uint32_t preds[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + preds[ii] = row_ + ii * (int)ROWS_PER_LDG < min((int)ROWS, actual_seqlen_); + preds[ii] &= col_in_bytes_ < VALID_BYTES_PER_ROW; + } + + // Prepare the load pointers. + void const* ptrs[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; + } + + // Trigger LDGSTS or the LDGs. + // The predicates protect against out-of-bound access in rows and cols + Ldgsts_helper::load(this, smem_tile, ptrs, preds); + } + + // Load data from memory. + inline __device__ void load() { + uint32_t preds[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + preds[ii] = row_ + ii * (int)ROWS_PER_LDG < min((int)ROWS, actual_seqlen_); + } + + // Prepare the load pointers. + void const* ptrs[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; + } + + // Trigger the LDGs. + if (col_in_bytes_ < VALID_BYTES_PER_ROW) { + fmha::pack_predicates(preds_, preds); + fmha::ldg(fetch_, ptrs, preds_); + } else { +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + fetch_[ii] = make_uint4(0u, 0u, 0u, 0u); + } + } + } + + // Move the pointer to the next row location. + inline __device__ void move(int const steps = 1) { + qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps; + actual_seqlen_ -= (int)ROWS * steps; + } + + // Move the pointer to the next row location by the offset (not step). 
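For the MQA/GQA branch of the packed-QKV constructor, each token row is laid out as [q_hd, k_h'd, v_h'd] and the per-head byte offsets are assembled exactly as in the code above. A host-side sketch with an assumed configuration (8 query heads, 2 KV heads, d = dv = 128, fp16), not taken from any shipped config:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical GQA configuration, for illustration only.
  const int num_heads = 8;       // query heads (h)
  const int num_kv_heads = 2;    // key/value heads (h')
  const int d = 128;             // head size of Q and K
  const int dv = 128;            // head size of V (can differ, e.g. for MLA)
  const int bytes_per_elt = 2;   // fp16

  const int bidh = 5;                                      // query head handled by this CTA
  const int kv_head = bidh / (num_heads / num_kv_heads);   // 5 / 4 = 1

  // Byte offsets inside one packed token row [q_hd, k_h'd, v_h'd], mirroring the ctor.
  int64_t q_off = (int64_t)bidh * d * bytes_per_elt;
  int64_t k_off = (int64_t)(num_heads + kv_head) * d * bytes_per_elt;
  int64_t v_off = (int64_t)(num_heads + num_kv_heads) * d * bytes_per_elt +
                  (int64_t)kv_head * dv * bytes_per_elt;

  std::printf("q_off=%lld k_off=%lld v_off=%lld row_bytes=%lld\n", (long long)q_off,
              (long long)k_off, (long long)v_off,
              (long long)((int64_t)(num_heads * d + num_kv_heads * (d + dv)) * bytes_per_elt));
  return 0;
}
```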
+ inline __device__ void move_by_offset(int const offset) { + qkv_ptr_ = qkv_ptr_init_ + (int64_t)offset * params_qkv_stride_in_bytes_; + actual_seqlen_ = actual_seqlen_init_ - (int)offset; + } + + // Move the pointer to the next column location + inline __device__ void move_col(int const steps = 1) { + qkv_ptr_ += (int64_t)COLS * (BITS_PER_ELEMENT / 8) * steps; + // Update col_in_bytes_ to ensure load predicates work + col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG * steps; + } + + inline __device__ void reset() { + qkv_ptr_ = qkv_ptr_init_; + actual_seqlen_ = actual_seqlen_init_; + } + + // Rewind the pointer back to previous column location + inline __device__ void rewind_col(int const steps) { + qkv_ptr_ -= COLS * (BITS_PER_ELEMENT / 8) * steps; + // Update col_in_bytes_ to ensure load predicates work + col_in_bytes_ -= THREADS_PER_ROW * BYTES_PER_LDG * steps; + } + + inline __device__ void move_to(int const step) { + qkv_ptr_ = qkv_ptr_init_ + (int64_t)ROWS * params_qkv_stride_in_bytes_ * step; + actual_seqlen_ = actual_seqlen_init_ - (int)ROWS * step; + } + + // Store data to memory. + inline __device__ void store(uint4 const (&data)[LDGS]) { +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + char* ptr = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; + if (((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen_)) && + col_in_bytes_ < VALID_BYTES_PER_ROW /*TODO: double check*/) { + fmha::stg(ptr, data[ii]); + } + } + } + + // The stride between rows for the QKV matrice. + int64_t params_qkv_stride_in_bytes_; + // The pointer. + char* qkv_ptr_; + char* qkv_ptr_init_; + // The register to store predicates. + uint32_t preds_[PRED_REGS]; + // The fetch registers. + uint4 fetch_[LDGS]; + // Keep track of the row and col the thread is processing as we move the tile. + int row_; + int col_in_bytes_; + // The sequence length. + int actual_seqlen_; + int actual_seqlen_init_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// We expect the Q/K/V layout to be [B, S, H, D] with variable sequence length support. +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The number of bits per element. + int BITS_PER_ELEMENT_, + // The number of rows of Q, K or V loaded by this tile. + int ROWS_, + // The number of columns (padded, e.g 64). + int COLS, + // The actual number of columns (unpadded, e.g 40) + int VALID_COLS_, + // Do we use LDGSTS? + bool USE_LDGSTS_, + // Are attention heads interleaved? (not used) + bool HEADS_INTERLEAVED = false, + // The number of matrices (not used) + int NUM_MATS = 1, + // Is sliding window attention used ? + bool SLIDING_WINDOW_ATTENTION = false> +struct Gmem_tile_q_k_v { + // The size of each LDG. + enum { BYTES_PER_LDG = 16 }; + + // The number of bits/bytes of element + enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ }; + + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT_ / 8 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 }; + + // The number of threads to load a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; + + // The valid size of a row in bytes (without paddings). + enum { VALID_COLS = VALID_COLS_ }; + + // The amount of bytes that are valid per row. 
+ enum { VALID_BYTES_PER_ROW = VALID_COLS * BITS_PER_ELEMENT / 8 }; + + // The number of "rows" loaded per LDG. + enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of rows. + enum { ROWS = ROWS_ }; + + // The number of LDGs needed to load a chunk of the Q matrix. + enum { LDGS = fmha::Div_up::VALUE }; + + // The number of predicate registers. + enum { PRED_REGS = fmha::Compute_number_of_pred_regs::VALUE }; + + // Is it Hopper? + enum { + IS_HOPPER = std::is_same::value == true + }; + + // Make sure we use a single register to store predicates. Do not throw for Hopper for now. + static_assert(!USE_LDGSTS_ || PRED_REGS == 1 || IS_HOPPER, ""); + + // We do not use LDGSTS (for the moment). + enum { USE_LDGSTS = USE_LDGSTS_ }; + + // Ctor + // qkv_offset: 0 for Q, 1 for K, 2 for V + template + inline __device__ Gmem_tile_q_k_v(bert::Fused_multihead_attention_params_v2 const& params, + int qkv_offset, Block_info const& binfo, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) { + int seq_offset = 0; + if (qkv_offset == 0) { + // Q tensor + params_q_k_v_stride_in_bytes_ = params.q_stride_in_bytes; + q_k_v_ptr_ = reinterpret_cast(params.q_ptr); + actual_seqlen_ = binfo.actual_q_seqlen; + seq_offset = binfo.sum_s; + } else if (qkv_offset == 1) { + // K tensor + params_q_k_v_stride_in_bytes_ = params.k_stride_in_bytes; + q_k_v_ptr_ = reinterpret_cast(params.k_ptr); + actual_seqlen_ = binfo.actual_kv_seqlen; + seq_offset = binfo.sum_s_kv; + } else if (qkv_offset == 2) { + // V tensor + params_q_k_v_stride_in_bytes_ = params.v_stride_in_bytes; + q_k_v_ptr_ = reinterpret_cast(params.v_ptr); + actual_seqlen_ = binfo.actual_kv_seqlen; + seq_offset = binfo.sum_s_kv; + } + + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // We must store the value to update the predicates in "load". + row_ = row; + // Do not load/store if the thread is in the padded area + col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_LDG; + + // The row offset in the batched GEMM, including the sequence offset. + int64_t row_offset = + (int64_t)(row + cta_row_offset + seq_offset) * params_q_k_v_stride_in_bytes_; + // Add the head index. + int64_t idx = binfo.bidh; + + // Assemble the final pointer. + q_k_v_ptr_ += row_offset + idx * VALID_BYTES_PER_ROW + col_in_bytes_; + + // Take the CTA offset to modify the sequence length. + actual_seqlen_ -= cta_row_offset; + + // Set the initial seq_len and qkv_offset in case of reinterating + actual_seqlen_init_ = actual_seqlen_; + q_k_v_ptr_init_ = q_k_v_ptr_; + } + + // Store data to shared memory. + template + inline __device__ void commit(Smem_tile& smem_tile) { + if (!USE_LDGSTS) { + smem_tile.store(fetch_); + } + } + + // Load data from memory. + template + inline __device__ void load(Smem_tile& smem_tile) { + uint32_t preds[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + preds[ii] = row_ + ii * (int)ROWS_PER_LDG < min((int)ROWS, actual_seqlen_); + preds[ii] &= col_in_bytes_ < VALID_BYTES_PER_ROW; + } + + // Prepare the load pointers. + void const* ptrs[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + ptrs[ii] = q_k_v_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_q_k_v_stride_in_bytes_; + } + + // Trigger LDGSTS or the LDGs. 
+ // The predicates protect against out-of-bound access in rows and cols + Ldgsts_helper::load(this, smem_tile, ptrs, preds); + } + + // Move the pointer to the next row location. + inline __device__ void move(int const steps = 1) { + q_k_v_ptr_ += (int64_t)ROWS * params_q_k_v_stride_in_bytes_ * steps; + actual_seqlen_ -= (int)ROWS * steps; + } + + // Move the pointer to the next row location by the offset (not step). + inline __device__ void move_by_offset(int const offset) { + q_k_v_ptr_ = q_k_v_ptr_init_ + (int64_t)offset * params_q_k_v_stride_in_bytes_; + actual_seqlen_ = actual_seqlen_init_ - (int)offset; + } + + // Move the pointer to the next column location + inline __device__ void move_col() { + q_k_v_ptr_ += (int64_t)COLS * (BITS_PER_ELEMENT / 8); + // Update col_in_bytes_ to ensure load predicates work + col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG; + } + + // Rewind the pointer back to previous column location + inline __device__ void rewind_col(int const steps) { + q_k_v_ptr_ -= COLS * (BITS_PER_ELEMENT / 8) * steps; + // Update col_in_bytes_ to ensure load predicates work + col_in_bytes_ -= THREADS_PER_ROW * BYTES_PER_LDG * steps; + } + + // Move the pointer to the specified step. + inline __device__ void move_to(int const step) { + q_k_v_ptr_ = q_k_v_ptr_init_ + (int64_t)ROWS * params_q_k_v_stride_in_bytes_ * step; + actual_seqlen_ = actual_seqlen_init_ - (int)ROWS * step; + } + + inline __device__ void reset() { + q_k_v_ptr_ = q_k_v_ptr_init_; + actual_seqlen_ = actual_seqlen_init_; + } + + // The stride between rows for the Q/K/V matrice. + int64_t params_q_k_v_stride_in_bytes_; + // The pointer. + char* q_k_v_ptr_; + char* q_k_v_ptr_init_; + // The register to store predicates. + uint32_t preds_[PRED_REGS]; + // The fetch registers. + uint4 fetch_[LDGS]; + // Keep track of the row and col the thread is processing as we move the tile. + int row_; + int64_t col_in_bytes_; + // The sequence length. + int actual_seqlen_; + int actual_seqlen_init_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Shape [B, S, 2, H, D] where S can be variable sequence length. +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The number of bits per element. + int BITS_PER_ELEMENT_, + // The number of rows of Q, K or V loaded by this tile. + int ROWS_, + // The number of columns (padded, e.g 64). + int COLS, + // The actual number of columns (unpadded, e.g 40) + int VALID_COLS_, + // Do we use LDGSTS? + bool USE_LDGSTS_, + // Are attention heads interleaved? (Not used) + bool HEADS_INTERLEAVED, + // The number of matrices (Not used) + int NUM_MATS = 2, + // Is sliding window attention used ? + bool SLIDING_WINDOW_ATTENTION = false> +struct Gmem_tile_contiguous_kv { + // The size of each LDG. + enum { BYTES_PER_LDG = 16 }; + + // The number of bits/bytes of element + enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ }; + + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT_ / 8 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 }; + + // The number of threads to load a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; + + // The valid size of a row in bytes (without paddings). + enum { VALID_COLS = VALID_COLS_ }; + + // The amount of bytes that are valid per row. + enum { VALID_BYTES_PER_ROW = VALID_COLS * BITS_PER_ELEMENT / 8 }; + + // The number of "rows" loaded per LDG. 
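Gmem_tile_q_k_v addresses Q, K and V through separate base pointers, each with a contiguous [total_tokens, H, D] layout: the variable-length sequence offset and the row go through the full token stride, the head selects a D-wide slice, and the thread adds its 16-byte column. A sketch of that pointer assembly with illustrative sizes only:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed contiguous K tensor: H = 8 heads, D = 64, fp16, variable-length batches.
  const int64_t k_stride_in_bytes = 8 * 64 * 2;  // bytes per token row (all heads)
  const int64_t valid_bytes_per_row = 64 * 2;    // one head's slice of that row
  const int sum_s_kv = 300;                      // token offset of this sequence in the batch
  const int bidh = 3;                            // head index
  const int row = 5, col_in_bytes = 32;          // per-thread coordinates inside the tile

  // Same assembly as the ctor: (row + seq offset) * token stride + head slice + thread column.
  int64_t ptr_offset = (int64_t)(row + sum_s_kv) * k_stride_in_bytes +
                       (int64_t)bidh * valid_bytes_per_row + col_in_bytes;
  std::printf("K byte offset for this thread: %lld\n", (long long)ptr_offset);
  return 0;
}
```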
+ enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of rows. + enum { ROWS = ROWS_ }; + + // The number of LDGs needed to load a chunk of the Q matrix. + enum { LDGS = fmha::Div_up::VALUE }; + + // The number of predicate registers. + enum { PRED_REGS = fmha::Compute_number_of_pred_regs::VALUE }; + + // Is it Hopper? + enum { + IS_HOPPER = std::is_same::value == true + }; + + // Make sure we use a single register to store predicates. Do not throw for Hopper for now. + static_assert(!USE_LDGSTS_ || PRED_REGS == 1 || IS_HOPPER, ""); + + // We do not use LDGSTS (for the moment). + enum { USE_LDGSTS = USE_LDGSTS_ }; + + // Ctor for bert::Fused_multihead_attention_params_v2 class + template + inline __device__ Gmem_tile_contiguous_kv(bert::Fused_multihead_attention_params_v2 const& params, + int qkv_offset, // q = 0, k = 1, v = 2. + Block_info const& binfo, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Gmem_tile_contiguous_kv(params.kv_ptr, params.k_stride_in_bytes, params.h_kv, + params.h_q_per_kv, qkv_offset, binfo, tidx, cta_row_offset, + cta_col_offset_in_bytes) {} + + // Ctor. + template + inline __device__ Gmem_tile_contiguous_kv(void* kv_ptr, size_t kv_stride_in_bytes, + int num_kv_heads, int head_group_size, int qkv_offset, + Block_info const& binfo, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : params_kv_stride_in_bytes_(kv_stride_in_bytes), + actual_seqlen_(binfo.actual_kv_seqlen), + kv_ptr_(reinterpret_cast(kv_ptr)) { + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // We must store the value to update the predicates in "load". + row_ = row; + // Do not load/store if the thread is in the padded area + col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_LDG; + + // The row offset in the batched GEMM. + int64_t row_offset = (int64_t)(row + cta_row_offset) * params_kv_stride_in_bytes_; + // [b, s, 2, h_kv, d]. + int64_t idx = + (binfo.sum_s_kv * 2 + qkv_offset - 1) * num_kv_heads + (binfo.bidh / head_group_size); + + // Assemble the final pointer. + kv_ptr_ += row_offset + idx * VALID_BYTES_PER_ROW + col_in_bytes_; + + // Take the CTA offset to modify the sequence length. + actual_seqlen_ -= cta_row_offset; + + // Set the initial seq_len and qkv_offset in case of reinterating + actual_seqlen_init_ = actual_seqlen_; + kv_ptr_init_ = kv_ptr_; + } + + // Store data to shared memory. + template + inline __device__ void commit(Smem_tile& smem_tile) { + if (!USE_LDGSTS) { + smem_tile.store(fetch_); + } + } + + // Load data from memory. + template + inline __device__ void load(Smem_tile& smem_tile) { + uint32_t preds[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + preds[ii] = row_ + ii * (int)ROWS_PER_LDG < min((int)ROWS, actual_seqlen_); + preds[ii] &= col_in_bytes_ < VALID_BYTES_PER_ROW; + } + + // Prepare the load pointers. + void const* ptrs[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + ptrs[ii] = kv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_kv_stride_in_bytes_; + } + + // Trigger LDGSTS or the LDGs. + // The predicates protect against out-of-bound access in rows and cols + Ldgsts_helper::load(this, smem_tile, ptrs, preds); + } + + // Load data from memory. 
+ inline __device__ void load() { + uint32_t preds[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + preds[ii] = row_ + ii * (int)ROWS_PER_LDG < min((int)ROWS, actual_seqlen_); + } + + // Prepare the load pointers. + void const* ptrs[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + ptrs[ii] = kv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_kv_stride_in_bytes_; + } + + // Trigger the LDGs. + if (col_in_bytes_ < VALID_BYTES_PER_ROW) { + fmha::pack_predicates(preds_, preds); + fmha::ldg(fetch_, ptrs, preds_); + } else { +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + fetch_[ii] = make_uint4(0u, 0u, 0u, 0u); + } + } + } + + // Move the pointer to the next row location. + inline __device__ void move(int const steps = 1) { + kv_ptr_ += (int64_t)ROWS * params_kv_stride_in_bytes_ * steps; + actual_seqlen_ -= (int)ROWS * steps; + } + + // Move the pointer to the next row location by the offset (not step). + inline __device__ void move_by_offset(int const offset) { + kv_ptr_ = kv_ptr_init_ + (int64_t)offset * params_kv_stride_in_bytes_; + actual_seqlen_ = actual_seqlen_init_ - (int)offset; + } + + // Move the pointer to the next column location + inline __device__ void move_col(int const steps = 1) { + kv_ptr_ += (int64_t)COLS * (BITS_PER_ELEMENT / 8) * steps; + // Update col_in_bytes_ to ensure load predicates work + col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG * steps; + } + + inline __device__ void reset() { + kv_ptr_ = kv_ptr_init_; + actual_seqlen_ = actual_seqlen_init_; + } + + // Rewind the pointer back to previous column location + inline __device__ void rewind_col(int const steps) { + kv_ptr_ -= COLS * (BITS_PER_ELEMENT / 8) * steps; + // Update col_in_bytes_ to ensure load predicates work + col_in_bytes_ -= THREADS_PER_ROW * BYTES_PER_LDG * steps; + } + + inline __device__ void move_to(int const step) { + kv_ptr_ = kv_ptr_init_ + (int64_t)ROWS * params_kv_stride_in_bytes_ * step; + actual_seqlen_ = actual_seqlen_init_ - (int)ROWS * step; + } + + // Store data to memory. + inline __device__ void store(uint4 const (&data)[LDGS]) { +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + char* ptr = kv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_kv_stride_in_bytes_; + if (((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen_)) && + col_in_bytes_ < VALID_BYTES_PER_ROW /*TODO: double check*/) { + fmha::stg(ptr, data[ii]); + } + } + } + + // The stride between rows for the QKV matrice. + int64_t params_kv_stride_in_bytes_; + // The pointer. + char* kv_ptr_; + char* kv_ptr_init_; + // The register to store predicates. + uint32_t preds_[PRED_REGS]; + // The fetch registers. + uint4 fetch_[LDGS]; + // Keep track of the row and col the thread is processing as we move the tile. + int row_; + int col_in_bytes_; + // The sequence length. + int actual_seqlen_; + int actual_seqlen_init_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// We expect the paged KV layout to be blocks of indices with shape of [B, 2, Blocks_per_Seq], +// and the indice tells the memory distance to the pool ptr in global memory. + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The number of bits per element. + int BITS_PER_ELEMENT_, + // The number of rows of Q, K or V loaded by this tile. + int ROWS_, + // The number of columns (padded, e.g 64). 
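Gmem_tile_contiguous_kv indexes a [b, s, 2, h_kv, d] cache: tokens are counted twice (one K plane and one V plane, selected by qkv_offset - 1) and the query head is folded down to its KV head by head_group_size. A small sketch of that index computation under assumed sizes (4 KV heads, group size 2, d = 64, fp16):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative contiguous KV cache laid out as [b, s, 2, h_kv, d].
  const int num_kv_heads = 4;       // h_kv
  const int head_group_size = 2;    // query heads per KV head (GQA group size)
  const int d_bytes = 64 * 2;       // head size in bytes (fp16, d = 64)
  const int sum_s_kv = 300;         // token offset of this sequence in the flattened batch
  const int bidh = 5;               // query head -> KV head 5 / 2 = 2

  for (int qkv_offset = 1; qkv_offset <= 2; ++qkv_offset) {  // 1 = K, 2 = V
    // Same head-row index as the ctor: each token contributes a K row and a V row,
    // then the KV head selects a d-wide slice.
    int64_t idx = ((int64_t)sum_s_kv * 2 + qkv_offset - 1) * num_kv_heads +
                  bidh / head_group_size;
    std::printf("%s head-row index = %lld (byte offset %lld)\n", qkv_offset == 1 ? "K" : "V",
                (long long)idx, (long long)(idx * d_bytes));
  }
  return 0;
}
```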
+ int COLS, + // The actual number of columns (unpadded, e.g 40) + int VALID_COLS_, + // Do we use LDGSTS? + bool USE_LDGSTS_, + // Are attention heads interleaved? (not used) + bool HEADS_INTERLEAVED = false, + // The number of matrices (not used) + int NUM_MATS = 2, + // Is sliding window attention used ? + bool SLIDING_WINDOW_ATTENTION_ = false> +struct Gmem_tile_paged_kv { + // The size of each LDG. + enum { BYTES_PER_LDG = 16 }; + + // The number of bits/bytes of element + enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ }; + + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT_ / 8 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 }; + + // The number of threads to load a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; + + // The valid size of a row in bytes (without paddings). + enum { VALID_COLS = VALID_COLS_ }; + + // The amount of bytes that are valid per row. + enum { VALID_BYTES_PER_ROW = VALID_COLS * BITS_PER_ELEMENT / 8 }; + + // The number of "rows" loaded per LDG. + enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of rows. + enum { ROWS = ROWS_ }; + + // The number of LDGs needed to load a chunk of the Q matrix. + enum { LDGS = fmha::Div_up::VALUE }; + + // The number of predicate registers. + enum { PRED_REGS = fmha::Compute_number_of_pred_regs::VALUE }; + + // Is sliding window attention used ? + enum { SLIDING_WINDOW_ATTENTION = SLIDING_WINDOW_ATTENTION_ }; + + // Is it Hopper? + enum { + IS_HOPPER = std::is_same::value == true + }; + + // Make sure we use a single register to store predicates. Do not throw for Hopper for now. + static_assert(!USE_LDGSTS_ || PRED_REGS == 1 || IS_HOPPER, ""); + + // We do not use LDGSTS (for the moment). + enum { USE_LDGSTS = USE_LDGSTS_ }; + + // Ctor. + template + inline __device__ Gmem_tile_paged_kv(bert::Fused_multihead_attention_params_v2 const& params, + int qkv_offset, // q = 0, k = 1, v = 2. + Block_info const& binfo, int tidx, int cta_row_offset = 0, + int cta_col_offset_in_bytes = 0) + : actual_seqlen_(binfo.actual_seqlen), + past_seqlen_(binfo.actual_seqlen - binfo.actual_q_seqlen), + sliding_window_size_(params.sliding_window_size), + paged_kv_log2_block_size_(params.paged_kv_cache.mTokensPerBlockLog2), + paged_kv_block_pool_ptr_(reinterpret_cast(params.paged_kv_cache.mPoolPtr)), + paged_kv_global_block_offsets_(params.paged_kv_cache.mBlockOffsets), + params_kv_block_size_in_bytes_(params.paged_kv_cache.mBytesPerBlock) { + // Handle Paged KV with shape [S, Dh], by offsetting it to the target batch. + int32_t const paged_kv_block_offset = + (binfo.bidb * 2 + qkv_offset - 1) * params.paged_kv_cache.mMaxBlocksPerSeq; + paged_kv_global_block_offsets_ += paged_kv_block_offset; + + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // We must store the value to update the predicates in "load". + row_ = row; + // Do not load/store if the thread is in the padded area + col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_LDG; + + int64_t kv_stride_in_bytes = + qkv_offset == 1 ? params.k_stride_in_bytes : params.v_stride_in_bytes; + // The head offset. 
+ head_stride_in_bytes_ = (int64_t)(binfo.bidh / params.h_q_per_kv) * kv_stride_in_bytes; + // When V is padded (like MLA), we cannot use VALID_BYTES_PER_ROW + token_stride_in_bytes_ = kv_stride_in_bytes >> paged_kv_log2_block_size_; + + // Take the CTA offset to modify the sequence length. + // Actually we don't need that for flash attention. + actual_seqlen_ -= cta_row_offset; + } + + // Store data to shared memory. + template + inline __device__ void commit(Smem_tile& smem_tile) { + if (!USE_LDGSTS) { + smem_tile.store(fetch_); + } + } + + // Load data from memory. + template + inline __device__ void load(Smem_tile& smem_tile) { + // Prepare the predicates. + uint32_t preds[LDGS]; + // Prepare the load pointers. + void const* ptrs[LDGS]; + + // Offset for the new paged kv pointer. + uint64_t const head_col_in_bytes = head_stride_in_bytes_ + col_in_bytes_; + +// Update paged_kv ptr for each LDG (reuse is possible). +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + int row_idx = row_ + ii * (int)ROWS_PER_LDG; + int paged_kv_block_idx = (row_idx >> paged_kv_log2_block_size_); + char const* local_kv_ptr = reinterpret_cast( + paged_kv_block_pool_ptr_ + + params_kv_block_size_in_bytes_ * paged_kv_global_block_offsets_[paged_kv_block_idx]); + + // Predicates. + // TODO: do we need to make sure row_idx < ROWS ? + preds[ii] = row_idx < actual_seqlen_; + preds[ii] &= col_in_bytes_ < VALID_BYTES_PER_ROW; + + // Pointers. + int row_idx_in_block = row_idx & ((1 << paged_kv_log2_block_size_) - 1); + ptrs[ii] = + local_kv_ptr + head_col_in_bytes + (int64_t)row_idx_in_block * token_stride_in_bytes_; + } + + // Trigger LDGSTS or the LDGs. + // The predicates protect against out-of-bound access in rows and cols + Ldgsts_helper::load(this, smem_tile, ptrs, preds); + } + + // Move the pointer to the next row location. + inline __device__ void move() { row_ += ROWS; } + + // Move the pointer to the next row location by the offset (not step). + inline __device__ void move_by_offset(int const offset) { row_ += offset; } + + // Move the pointer to the next column location + inline __device__ void move_col() { col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG; } + + // Rewind the pointer back to previous column location + inline __device__ void rewind_col(int const steps) { + // Update col_in_bytes_ to ensure load predicates work + col_in_bytes_ -= THREADS_PER_ROW * BYTES_PER_LDG * steps; + } + + // The stride between rows for the KV matrice. + int64_t params_kv_block_size_in_bytes_; + // The paged cache pool pointer. + char* paged_kv_block_pool_ptr_; + // The paged block offsets. + int32_t* paged_kv_global_block_offsets_; + // The paged block size. + int paged_kv_log2_block_size_; + // The register to store predicates. + uint32_t preds_[PRED_REGS]; + // The fetch registers. + uint4 fetch_[LDGS]; + // Keep track of the row and col the thread is processing as we move the tile. + int row_; + int64_t col_in_bytes_; + // Keep track of the head offset. + int64_t head_stride_in_bytes_; + // // for DeepSeek MLA, the stride of V tokens != VALID_BYTES_PER_ROW + int32_t token_stride_in_bytes_; + // The sequence length. + int actual_seqlen_; + // The past sequence length (kv_seqlen - q_seqlen) considering chunked context. + int past_seqlen_; + // The sliding attention window size. + int sliding_window_size_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. 
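The paged-KV load() decomposes the absolute token index using the block size's log2: the high bits pick an entry in the per-sequence block table, the low bits pick the token inside that block, and the final address adds the head/column offset plus the in-block row times the token stride. A host-side sketch of that address math with made-up pool parameters:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Illustrative paged-KV pool: 16 tokens per block (log2 = 4), 64 KiB per block.
  const int log2_tokens_per_block = 4;
  const int64_t bytes_per_block = 64 * 1024;
  // Per-sequence block table (the role of mBlockOffsets for one batch/K-or-V plane).
  std::vector<int32_t> block_offsets = {7, 3, 12, 5};
  const int64_t token_stride_in_bytes = 128;  // bytes between tokens inside a block
  const int64_t head_col_in_bytes = 256;      // head offset plus this thread's column

  const int row_idx = 37;  // absolute token index within the sequence
  // Same decomposition as load(): high bits select the block, low bits the in-block token.
  int block_idx = row_idx >> log2_tokens_per_block;                  // 37 / 16 = 2
  int row_in_block = row_idx & ((1 << log2_tokens_per_block) - 1);   // 37 % 16 = 5
  int64_t addr = bytes_per_block * block_offsets[block_idx] +        // jump into the pool
                 head_col_in_bytes + (int64_t)row_in_block * token_stride_in_bytes;
  std::printf("block=%d row_in_block=%d pool byte offset=%lld\n", block_idx, row_in_block,
              (long long)addr);
  return 0;
}
```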
+ typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The number of bits per element. + int BITS_PER_ELEMENT, + // The number of rows of Q loaded by this tile. + int ROWS_, + // The number of columns. + int COLS, + // Do we use LDGSTS? + bool USE_LDGSTS_, + // Are attention heads interleaved? + bool HEADS_INTERLEAVED, + // The number of matrices + int NUM_MATS = 1> +struct Gmem_tile_q_kv { + // The size of each LDG. + enum { BYTES_PER_LDG = 16 }; + + // The padded to the next power of 2 number of columns + enum { COLS_PADDED = Next_power_of_two::VALUE }; + + // The padded size of a row in bytes. + enum { BYTES_PER_ROW_PADDED = COLS_PADDED * BITS_PER_ELEMENT / 8 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 }; + + // The number of threads to load a padded "row" of the matrix. + enum { THREADS_PER_ROW_PADDED = BYTES_PER_ROW_PADDED / BYTES_PER_LDG }; + + // The number of threads to load a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; + + // The number of "rows" loaded per LDG. + enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW_PADDED }; + + // The number of rows. + enum { ROWS = ROWS_ }; + + // The number of LDGs needed to load a chunk of the Q matrix. + enum { LDGS = fmha::Div_up::VALUE }; + + // The number of predicate registers. + enum { PRED_REGS = fmha::Compute_number_of_pred_regs::VALUE }; + + // Is it Hopper? + enum { + IS_HOPPER = std::is_same::value == true + }; + + // Make sure we use a single register to store predicates. Do not throw for Hopper for now. + static_assert(!USE_LDGSTS_ || PRED_REGS == 1 || IS_HOPPER, ""); + + // We do not use LDGSTS (for the moment). + enum { USE_LDGSTS = USE_LDGSTS_ }; + + // Ctor. + template + inline __device__ Gmem_tile_q_kv(Params const& params, int offset, Block_info const& binfo, + int tidx, int cta_row_offset = 0) + : params_stride_in_bytes_(params.stride_in_bytes), + actual_seqlen_(binfo.actual_seqlen), + ptr_(reinterpret_cast(params.ptr)) { + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW_PADDED; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW_PADDED; + + // We must store the value to update the predicates in "load". + row_ = row; + // Mask for predicate if the channels are in the padded area + int const bytes_per_row_non_padded = params.d * BITS_PER_ELEMENT / 8; + mask_ = col < bytes_per_row_non_padded / BYTES_PER_LDG; + + // The row offset in the batched GEMM. For each seq element, we store QKV in that order. + int64_t row_offset = (int64_t)(row + cta_row_offset) * params.stride_in_bytes; + // Add the block index. + int64_t idx; + if (HEADS_INTERLEAVED) { + idx = binfo.bidx * NUM_MATS + offset; + } else { + idx = (binfo.sum_s * NUM_MATS + offset) * params.h + binfo.bidh; + } + // Assemble the final pointer. + ptr_ += row_offset + idx * bytes_per_row_non_padded + col * BYTES_PER_LDG; + + // Take the CTA offset to modify the sequence length. + actual_seqlen_ -= cta_row_offset; + } + + // Store data to shared memory. + template + inline __device__ void commit(Smem_tile& smem_tile) { + if (!USE_LDGSTS) { + smem_tile.store(fetch_); + } + } + + // Load data from memory. 
+ template + inline __device__ void load(Smem_tile& smem_tile) { + uint32_t preds[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + preds[ii] = (row_ + ii * (int)ROWS_PER_LDG < min((int)ROWS, actual_seqlen_)) && mask_; + } + + // Prepare the load pointers. + void const* ptrs[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + ptrs[ii] = ptr_ + (int64_t)ii * ROWS_PER_LDG * params_stride_in_bytes_; + } + + // Trigger LDGSTS or the LDGs. + Ldgsts_helper::load(this, smem_tile, ptrs, preds); + } + + inline __device__ void move(int const steps = 1) { + ptr_ += (int64_t)ROWS * params_stride_in_bytes_ * steps; + actual_seqlen_ -= (int)ROWS * steps; + } + + // Store data to memory. + inline __device__ void store(uint4 const (&data)[LDGS]) { +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + char* ptr = ptr_ + (int64_t)ii * ROWS_PER_LDG * params_stride_in_bytes_; + if ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen_)) { + fmha::stg(ptr, data[ii]); + } + } + } + + // The stride between rows for the matrix. + int64_t params_stride_in_bytes_; + // The pointer. + char* ptr_; + // The register to store predicates. + uint32_t preds_[PRED_REGS]; + // The fetch registers. + uint4 fetch_[LDGS]; + // Keep track of the row and col the thread is processing as we move the tile. + int row_; + // Keep track of predicate state that depends only on the initialization state. + int mask_; + // The sequence length. + int actual_seqlen_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The number of bits per element. + int BITS_PER_ELEMENT, + // The number of rows of Q, K or V loaded by this tile. + int ROWS_, + // The number of columns. + int COLS, + // Do we use LDGSTS? + bool USE_LDGSTS_> +struct Gmem_tile_qkv_interleaved { + // The vectorization width for NC/32HW32. + enum { VEC = 32 }; + + // The size of each LDG. + enum { BYTES_PER_LDG = 16 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = VEC * BITS_PER_ELEMENT / 8 }; + + // DEBUG. + static_assert(BYTES_PER_ROW == 32, ""); + + // END OF DEBUG. + + // The number of threads to load a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; + + // DEBUG. + static_assert(THREADS_PER_ROW == 2, ""); + + // END OF DEBUG. + + // The number of "rows" loaded per LDG. + enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of slices. It is either 1 for DIM_PER_HEAD == 32 and 2 for DIM_PER_HEAD == 64. + enum { NUM_SLICES = COLS / VEC }; + + // DEBUG. + static_assert(NUM_SLICES == 1 || NUM_SLICES == 2, ""); + + // END OF DEBUG. + + // The number of rows in a slice. + enum { ROWS = ROWS_ }; + + // The number of LDGs needed to load a chunk of the Q matrix. + enum { LDGS = fmha::Div_up::VALUE }; + + // The number of predicate registers. + enum { PRED_REGS = fmha::Compute_number_of_pred_regs::VALUE }; + + // Make sure we use a single register to store predicates. + static_assert(PRED_REGS == 1, ""); + + // Do we use LDGSTS on Ampere? + enum { USE_LDGSTS = USE_LDGSTS_ }; + + // Ctor. 
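Gmem_tile_q_kv pads the head size to the next power of two for addressing and then masks off the threads whose 16-byte slot starts past the real row length (the mask_ predicate in the constructor). A sketch with an assumed non-power-of-two head size of d = 40:

```cpp
#include <cstdio>

// Next power of two, the same idea as the Next_power_of_two helper used above.
static int next_pow2(int n) {
  int p = 1;
  while (p < n) p *= 2;
  return p;
}

int main() {
  const int d = 40;                 // assumed head size that is not a power of two
  const int bits_per_element = 16;  // fp16
  const int bytes_per_ldg = 16;

  const int cols_padded = next_pow2(d);                                     // 64
  const int bytes_per_row_padded = cols_padded * bits_per_element / 8;      // 128
  const int threads_per_row_padded = bytes_per_row_padded / bytes_per_ldg;  // 8
  const int bytes_per_row_valid = d * bits_per_element / 8;                 // 80

  // Threads whose 16B slot starts past the real row length are masked, like mask_ above.
  for (int col = 0; col < threads_per_row_padded; ++col) {
    bool active = col < bytes_per_row_valid / bytes_per_ldg;  // 80 / 16 = 5 active slots
    std::printf("col %d: %s\n", col, active ? "loads" : "masked");
  }
  return 0;
}
```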
+ template + inline __device__ Gmem_tile_qkv_interleaved(Params const& params, int qkv_select, + Block_info const& block_info, int tidx, + int cta_row_offset = 0) + : actual_seqlen_(block_info.actual_seqlen - cta_row_offset), + total_(params.q_stride_in_bytes), + kv_ptr_(reinterpret_cast(params.qkv_ptr)) { + int bidh = block_info.bidh; + int sum_s = block_info.sum_s; + + // We must keep track of the row to repack predicates in load. + row_ = tidx / THREADS_PER_ROW; + // The column. + int col = tidx % THREADS_PER_ROW; + + // h is N + // d is H + // we get the data in as: 3 x h x (d/32) x total x 32 (think 3 x h x (d/32) + // x b x s x 32) + + // Loading qkv: ignore slice for now. + int qkv_offset = qkv_select * params.h * NUM_SLICES * total_; + // bidh * GROUPS * B * S + b * S. + int block_offset = bidh * NUM_SLICES * total_ + sum_s; + // The row offset. + int row_offset = (qkv_offset + block_offset + cta_row_offset) * BYTES_PER_ROW; + + // That's the pointer to load from (see "load"). + kv_ptr_ += row_offset + col * BYTES_PER_LDG; + + init_actual_seqlen_ = actual_seqlen_; + init_kv_ptr_ = kv_ptr_; + } + + // Store data to shared memory. + template + inline __device__ void commit(Smem_tile& smem_tile) { + if (!USE_LDGSTS) { + smem_tile.store(fetch_); + } + } + + // Load data from memory. + template + inline __device__ void load(Smem_tile& smem_tile) { + void const* ptrs[LDGS]; + uint32_t preds[LDGS]; + +// We precompute slice offsets and predicates +#pragma unroll + for (int ii = 0; ii < LDGS; ii++) { + // the next row + int row_i = row_ + ii * ROWS_PER_LDG; + + // Decompose the current row in slice and original row + int slice = row_i / ROWS; + // The position in the slice. + int row_in_slice = row_i % ROWS; + + // Update the predicate. + preds[ii] = row_in_slice < min(actual_seqlen_, ROWS); + // Compute the pointer. + ptrs[ii] = &kv_ptr_[(slice * total_ + row_in_slice) * BYTES_PER_ROW]; + } + + // Update the predicate register. + fmha::pack_predicates(preds_, preds); + + // Trigger the loads. + if (USE_LDGSTS) { + smem_tile.store(ptrs, preds_); + } else { + fmha::ldg(fetch_, ptrs, preds_); + } + } + + // Move the pointer to the next location. + inline __device__ void move(int const steps = 1) { + kv_ptr_ += (int64_t)ROWS * BYTES_PER_ROW * steps; + actual_seqlen_ -= ROWS * steps; + } + + // Reset to the initial location. + inline __device__ void reset() { + kv_ptr_ = init_kv_ptr_; + actual_seqlen_ = init_actual_seqlen_; + } + + // The pointer. + char const* kv_ptr_; + char const* init_kv_ptr_; + // The register to store predicates. + uint32_t preds_[PRED_REGS]; + // The fetch registers. + uint4 fetch_[LDGS]; + // keep track of the row the thread is processing as we move the tile + int row_; + // The sequence length. + int actual_seqlen_; + int init_actual_seqlen_; + // The number of rows per slice?? + int total_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace v2 +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/arrive_wait.h b/csrc/fmha_v2/fmha/hopper/arrive_wait.h new file mode 100644 index 0000000000..6448d82607 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/arrive_wait.h @@ -0,0 +1,396 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. 
SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +// CP ASYNC FEATURES /////////////////////////////////////////////////////////////////////////////// +#if !defined(CUDA_CP_ASYNC_SUPPORTED) && \ + ((__CUDACC_VER_MAJOR__ >= 11) || \ + ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 2))) +#define CUDA_CP_ASYNC_SUPPORTED 1 +#endif + +#if !defined(CUDA_CP_ASYNC_ENABLED) && (CUDA_CP_ASYNC_SUPPORTED) +#define CUDA_CP_ASYNC_ENABLED 1 +#endif + +#if CUDA_CP_ASYNC_ENABLED && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) +#define CUDA_CP_ASYNC_ACTIVATED 1 +#endif + +#if !defined(CUDA_CP_ASYNC_GROUP_POLICY_SUPPORTED) && (CUDA_CP_ASYNC_SUPPORTED) && \ + (__CUDACC_VER_MAJOR__ >= 11) +#define CUDA_CP_ASYNC_GROUP_POLICY_SUPPORTED 1 +#endif + +#if !defined(CUDA_CP_ASYNC_GROUP_POLICY_ENABLED) && (CUDA_CP_ASYNC_GROUP_POLICY_SUPPORTED) +#define CUDA_CP_ASYNC_GROUP_POLICY_ENABLED 1 +#endif + +#if CUDA_CP_ASYNC_GROUP_POLICY_ENABLED && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) +#define CUDA_CP_ASYNC_GROUP_POLICY_ACTIVATED 1 +#endif + +#if !defined(CUDA_CP_ASYNC_MBARRIER_ARRIVE_SUPPORTED) && (CUDA_CP_ASYNC_SUPPORTED) && \ + (__CUDACC_VER_MAJOR__ >= 11) +#define CUDA_CP_ASYNC_MBARRIER_ARRIVE_SUPPORTED 1 +#endif + +#if !defined(CUDA_CP_ASYNC_MBARRIER_ARRIVE_ENABLED) && (CUDA_CP_ASYNC_MBARRIER_ARRIVE_SUPPORTED) +#define CUDA_CP_ASYNC_MBARRIER_ARRIVE_ENABLED 1 +#endif + +#if (CUDA_CP_ASYNC_MBARRIER_ARRIVE_ENABLED) && (__CUDA_ARCH__ >= 800) +#define CUDA_CP_ASYNC_MBARRIER_ARRIVE_ACTIVATED 1 +#endif + +#if (CUDA_CP_ASYNC_MBARRIER_ARRIVE_ACTIVATED) && (CUDACC_VERSION >= 111) +#define CUDA_CP_ASYNC_MBARRIER_WAIT_ACTIVATED 1 +#endif + +#if !defined(FMHA_PTX_MBARRIER_TRYWAIT_NOSLEEP_INTERNAL_SUPPORT_ENABLED) +#define FMHA_PTX_MBARRIER_TRYWAIT_NOSLEEP_INTERNAL_SUPPORT_ENABLED 0 +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +inline __device__ void named_barrier_arrive(uint32_t BARRIER_ID, uint32_t NUM_THREADS) { + if (NUM_THREADS > 1) { + asm volatile("bar.arrive %0, %1;" : : "r"(BARRIER_ID), "r"(NUM_THREADS)); + } +} + +inline __device__ void named_barrier_wait(uint32_t BARRIER_ID, uint32_t NUM_THREADS) { + if (NUM_THREADS > 1) { + asm volatile("bar.sync %0, %1;" ::"r"(BARRIER_ID), "r"(NUM_THREADS)); + } +} + +// it is executed per thread, i.e., each thread can call and init a barrier. +// need a bar.sync after using it. 
+inline __device__ void bar_create(void* bar_ptr, int init_count) { + unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr); + + asm volatile( + "{\n\t" +#if CUDA_CP_ASYNC_MBARRIER_ARRIVE_ACTIVATED + "mbarrier.init.shared.b64 [%1], %0; \n\t" +#else + ".reg .s32 negCnt, count, expectedCount;\n\t" + ".reg .s64 comboCnt; \n\t" + "neg.s32 negCnt, %0;\n\t " + "and.b32 count, negCnt, 0x7fffffff; \n\t" + "and.b32 expectedCount, negCnt, 0x3fffffff; \n\t" + "mov.b64 comboCnt, {expectedCount, count}; \n\t" + "st.shared.s64 [%1], comboCnt; \n\t" +#endif + "}" + : + : "r"(init_count), "r"(smem_ptr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Arrive_wait { + public: + inline __device__ Arrive_wait() { bar_base_ = NULL; } + + inline __device__ Arrive_wait(uint64_t* bar_base, int id = 0) { + bar_base_ = bar_base; + id_ = id; + } + + inline __device__ uint64_t* get_bar_addr(int32_t id) { + return reinterpret_cast(bar_base_ + id); + } + + inline __device__ int bar_peek(int id, unsigned int bar_phase) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr); + uint32_t result32; +#if FMHA_PTX_MBARRIER_TRYWAIT_NOSLEEP_INTERNAL_SUPPORT_ENABLED + asm volatile( + "{\n\t" + ".reg .pred P3; \n\t" + "mbarrier.try_wait.parity.nosleep.shared.b64 P3, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P3; \n\t" + "}" + : "=r"(result32) + : "r"(smem_ptr), "r"(bar_phase)); +#else + // public ptx default heruistic generate SASS equal to with .nosleep in internal ptx + asm volatile( + "{\n\t" + ".reg .pred P3; \n\t" + "mbarrier.try_wait.parity.shared.b64 P3, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P3; \n\t" + "}" + : "=r"(result32) + : "r"(smem_ptr), "r"(bar_phase)); +#endif + return result32; +#else + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned int output_phase = (bar_ptr[0] >> 63) & 1; + + return output_phase != bar_phase; +#endif + } + + inline __device__ int bar_peek(int id, unsigned int bar_phase, int pred) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr); + uint32_t result32; +#if FMHA_PTX_MBARRIER_TRYWAIT_NOSLEEP_INTERNAL_SUPPORT_ENABLED + asm volatile( + "{\n\t" + ".reg .pred P3; \n\t" + ".reg .pred P2;\n\t" + "setp.eq.u32 P2, %3, 1;\n\t" + "@P2 mbarrier.try_wait.parity.nosleep.shared.b64 P3, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P3; \n\t" + "}" + : "=r"(result32) + : "r"(smem_ptr), "r"(bar_phase), "r"(pred)); +#else + // public ptx default heruistic generate SASS equal to with .nosleep in internal ptx + asm volatile( + "{\n\t" + ".reg .pred P3; \n\t" + ".reg .pred P2;\n\t" + "setp.eq.u32 P2, %3, 1;\n\t" + "@P2 mbarrier.try_wait.parity.shared.b64 P3, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P3; \n\t" + "}" + : "=r"(result32) + : "r"(smem_ptr), "r"(bar_phase), "r"(pred)); +#endif + return result32; +#else + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned int output_phase = (bar_ptr[0] >> 63) & 1; + + return output_phase != bar_phase; +#endif + } + + inline __device__ void bar_wait(int id, unsigned int bar_phase) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr); + uint32_t large_val = 0x989680; + asm volatile( + "{\n\t" + ".reg .pred P3; \n\t" + "LAB_WAIT: \n\t" + 
//"mbarrier.try_wait.parity.b64 P3, [%0], %1; \n\t" + "mbarrier.try_wait.parity.shared.b64 P3, [%0], %1, %2; \n\t" + "@P3 bra.uni DONE; \n\t" + "bra.uni LAB_WAIT; \n\t" + "DONE: \n\t" + "}" + : + : "r"(smem_ptr), "r"(bar_phase), "r"(large_val)); +#else + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr); + + asm volatile( + "{\n\t" + ".reg .pred P3; \n\t" +#ifdef CUDA_CP_ASYNC_MBARRIER_WAIT_ACTIVATED + "mbarrier.test_wait.parity.shared.b64 P3, [%0], %1;\n\t" +#else + ".reg .s32 high, low; \n\t" + ".reg .u32 currentPhase; \n\t" + "ld.volatile.shared.v2.s32 { low, high }, [%0]; \n\t" + "shr.u32 currentPhase, high, 31; \n\t" + "setp.ne.u32 P3, currentPhase, %1; \n\t" +#endif + "@P3 bra.uni DONE; \n\t" + "LAB_WAIT: \n\t" +#ifdef CUDA_CP_ASYNC_MBARRIER_WAIT_ACTIVATED + "mbarrier.test_wait.parity.shared.b64 P3, [%0], %1;\n\t" +#else + "ld.volatile.shared.v2.s32 { low, high }, [%0]; \n\t" + "shr.u32 currentPhase, high, 31; \n\t" + "setp.ne.u32 P3, currentPhase, %1; \n\t" +#endif + "@P3 bra.uni DONE; \n\t" + "bra.uni LAB_WAIT; \n\t" + "DONE: \n\t" + "}" + : + : "r"(smem_ptr), "r"(bar_phase)); +#endif + } + + // Set the expected_transaction_count and add 1 arrive count (1 transaction = 1 Byte) + // This PTX maps to SYNCS.ARRIVES.TRANS64.A1TR. + inline __device__ void bar_arrive_set_transactioncnt(int id, int expected_copy_bytes) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr); + asm volatile( + "{\n\t" + "mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1; \n\t" + "}" + : + : "r"(smem_ptr), "r"(expected_copy_bytes)); +#endif + } + + // Set the expected_transaction_count and add 1 arrive count (1 transaction = 1 Byte) + // This PTX maps to SYNCS.ARRIVES.TRANS64.A1TR. 
+ inline __device__ void bar_arrive_set_transactioncnt(int id, int expected_copy_bytes, + uint32_t pred) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr); + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1; \n\t" + "}" + : + : "r"(smem_ptr), "r"(expected_copy_bytes), "r"(pred)); +#endif + } + + // Sends barrier arrive notification to DSMEM + // Note this uses a slightly different syntax compared to normal arrive + // NOTE : Caller has to ensure that set_bar_base_dsmem has been called prior to using this + // This is done as a compiler optimizations (since set barrier base is independent) + inline __device__ void bar_arrive_dsmem(int const& id) { +#if CUDA_CP_ASYNC_MBARRIER_ARRIVE_ACTIVATED + + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + // TODO : check with PTX team on setctarank (currently emitting errors) + // asm volatile("{\n\t" + //"setctarank.shared.u32 %0, %1, %2;\n\t" + //"}" + // : "=r"(dst_ptr) : "r"(smem_ptr), "r"(cta_id)); + + asm volatile( + "{\n\t" + "mbarrier.arrive.b64 _, [%0];\n\t" + "}" + : + : "l"(bar_ptr)); +#endif + } + + // Just a predicated version of the above function + // Manually inlining it - since the compiler generates BRA instructions at the moment + // NOTE : Caller has to ensure that set_bar_base_dsmem has been called prior to using this + // This is done as a compiler optimizations (since set barrier base is independent) + inline __device__ void bar_arrive_dsmem(int const& id, uint32_t const& pred) { +#if CUDA_CP_ASYNC_MBARRIER_ARRIVE_ACTIVATED + asm volatile( + "{\n\t" + " .reg .pred p;\n\t" + " .reg .s64 addr;\n\t" + " .reg .b64 tmp;\n\t" + " setp.eq.u32 p, %2, 1;\n\t" + " mul.wide.s32 tmp, %0, 8;\n\t" + " add.s64 addr, tmp, %1;\n\t" + "@p mbarrier.arrive.b64 _, [addr];\n\t" + "}" + : + : "r"(id), "l"(bar_base_), "r"(pred)); +#endif + } + + // Sets up the base address for arrival with the correct ctaid in cga + inline __device__ void set_bar_base_dsmem(uint32_t const& cta_id) { + bar_base_ = reinterpret_cast( + ((unsigned long long int)bar_base_ & 0xFFFFFFFFF0FFFFFFULL) + (cta_id << 24)); + } + + inline __device__ void bar_arrive_normal(int id, bool flag = true) { +#if CUDA_CP_ASYNC_ACTIVATED && !(CUDA_CP_ASYNC_MBARRIER_ARRIVE_ACTIVATED) + asm("membar.cta;"); +#endif + + // to make distance for the dependence between atoms.arrive and shfl + if (flag == true) { + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr); + +#if CUDA_CP_ASYNC_MBARRIER_ARRIVE_ACTIVATED + + asm volatile( + "{\n\t" + ".reg .b64 state; \n\t" + "mbarrier.arrive.shared.b64 state, [%0];\n\t" + "}" + : + : "r"(smem_ptr)); + +#elif CUDA_CP_ASYNC_ACTIVATED + + asm volatile( + "{\n\t" + ".reg .b64 state; \n\t" + "atom.shared.arrive.b64 state, [%0];" + "}" + : + : "r"(smem_ptr)); +#endif + } + } + + inline __device__ void bar_arrive_ldgsts(int id) { + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr); + +#if CUDA_CP_ASYNC_MBARRIER_ARRIVE_ACTIVATED + asm volatile("cp.async.mbarrier.arrive.noinc.shared.b64 [%0];" : : "r"(smem_ptr)); +#elif CUDA_CP_ASYNC_ACTIVATED + asm volatile("cp.async.arrive.shared.b64 [%0];" : : "r"(smem_ptr)); +#endif + } + + inline __device__ uint64_t* bar_base() { return bar_base_; } + + private: + // smem barrier base pointer + uint64_t* 
bar_base_; + // barrier id + int id_; +}; + +// Set the expected_transaction_count and add 1 arrive count (1 transaction = 1 Byte) +// This PTX maps to SYNCS.ARRIVES.TRANS64.A1TR. +inline __device__ void bar_arrive_set_transactioncnt(unsigned smem_ptr, + unsigned expected_copy_bytes) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile( + "{\n\t" + "mbarrier.arrive.expect_copy.shared.b64 _, [%0], %1; \n\t" + "}" + : + : "r"(smem_ptr), "r"(expected_copy_bytes)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/compute_tile.h b/csrc/fmha_v2/fmha/hopper/compute_tile.h new file mode 100644 index 0000000000..e08c36fc7f --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/compute_tile.h @@ -0,0 +1,503 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include + +namespace fmha { +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Compute_tile_with_gmma {}; + +/* +compute tile used when both operands are coming from SMEM +*/ +template +struct Compute_tile_with_gmma { + static constexpr int NUM_KBLOCKS = Smem_tile_b::BUFFERS_PER_TILE / Cta_tile::WARPS_K; + static_assert(NUM_KBLOCKS * Cta_tile::WARPS_K == Smem_tile_b::BUFFERS_PER_TILE); + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // desc for A and B should have the same strategy + static_assert(Smem_tile_a::Gmma_descriptor::GMMA_DESC_SIZE_PER_GROUP == + Smem_tile_b::Gmma_descriptor::GMMA_DESC_SIZE_PER_GROUP, + "GMMA desc for A and B should have the same strategy."); + + // The number of MMAs. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + enum { MMAS_K = Mma_tile::MMAS_K }; + + // Ctor. + inline __device__ Compute_tile_with_gmma() {} + + // Ctor, that helps set the gmma descs to support different buffer index as the start address. + inline __device__ Compute_tile_with_gmma(void* a_smem_, void* b_smem_) + : Compute_tile_with_gmma(__nvvm_get_smem_pointer(a_smem_), __nvvm_get_smem_pointer(b_smem_)) { + } + + inline __device__ Compute_tile_with_gmma(uint32_t a_smem_base, uint32_t b_smem_base) + : a_smem_base_(a_smem_base), b_smem_base_(b_smem_base) { + // We always start at buffer 0. + uint32_t a_smem = a_smem_base_; + uint32_t b_smem = b_smem_base_; + +#pragma unroll + for (int mma_m_idx = 0; mma_m_idx < MMAS_M; ++mma_m_idx) { + gmma_desc_a_[mma_m_idx].set_smem_pointer(a_smem + + mma_m_idx * Smem_tile_a::GMMA_GROUP_SMEM_DISTANCE); + // We take the number of buffers directly from the Smem_tile. If we have only one buffer, the + // return offset is 0. 
+ gmma_desc_a_[mma_m_idx].set_max_descriptor_0(Smem_tile_a::BYTES_PER_BUFFER_NO_4LSB * + (Smem_tile_a::BUFFERS_PER_TILE - 1)); + } + +#pragma unroll + for (int mma_n_idx = 0; mma_n_idx < MMAS_N; ++mma_n_idx) { + gmma_desc_b_[mma_n_idx].set_smem_pointer(b_smem + + mma_n_idx * Smem_tile_b::GMMA_GROUP_SMEM_DISTANCE); + gmma_desc_b_[mma_n_idx].set_max_descriptor_0(Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB * + (Smem_tile_b::BUFFERS_PER_TILE - 1)); + } + } + + // move the gmme desc by N buffers. + // Something nice to have if we have persistent kernels. + inline __device__ void increment_N_gmma_desc_group(int N) { +#pragma unroll + for (int idx = 0; idx < Smem_tile_a::Gmma_descriptor::NUM_DESCRIPTORS; ++idx) { +#pragma unroll + for (int mma_m_idx = 0; mma_m_idx < MMAS_M; ++mma_m_idx) { + uint64_t temp_desc = gmma_desc_a_[mma_m_idx].get_descriptor(idx); + int2& tmp = reinterpret_cast(temp_desc); + tmp.x = (tmp.x & 0xFFFF0000) + (a_smem_base_ / 16) + + mma_m_idx * Smem_tile_a::GMMA_GROUP_SMEM_DISTANCE / 16 + + N * Smem_tile_a::BYTES_PER_BUFFER_NO_4LSB; + gmma_desc_a_[mma_m_idx].set_descriptor(idx, temp_desc); + } + +#pragma unroll + for (int mma_n_idx = 0; mma_n_idx < MMAS_N; ++mma_n_idx) { + uint64_t temp_desc = gmma_desc_b_[mma_n_idx].get_descriptor(idx); + int2& tmp = reinterpret_cast(temp_desc); + tmp.x = + (tmp.x & 0xFFFF0000) + (b_smem_base_ / 16) + N * Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB; + gmma_desc_b_[mma_n_idx].set_descriptor(idx, temp_desc); + } + } + } + + // Clear the accumulators. It does nothing as we have a special flag for GMMA. + inline __device__ void clear() { fmha::clear(acc_); } + + // smarter way of increment a group of gmma desc. + // if one of them need to be reset to the first ldgsts buffer + // it is very likely (currently guaranteed) that all of them need to be reset to the first + // ldgsts buffer. + // we do this to save the usage of uniform register. Otherwise, kernel with larger M could not + // achieve sol. + inline __device__ void increment_gmma_desc_group() { + bool reset_buffer_a = + gmma_desc_a_[0].get_descriptor(0) >= gmma_desc_a_[0].get_max_descriptor_0(); + bool reset_buffer_b = + gmma_desc_b_[0].get_descriptor(0) >= gmma_desc_b_[0].get_max_descriptor_0(); + +#pragma unroll + for (int idx = 0; idx < Smem_tile_a::Gmma_descriptor::NUM_DESCRIPTORS; ++idx) { +#pragma unroll + for (int mma_m_idx = 0; mma_m_idx < MMAS_M; ++mma_m_idx) { + uint64_t temp_desc = gmma_desc_a_[mma_m_idx].get_descriptor(idx); + // smem start address is in lower 32bits + int2& tmp = reinterpret_cast(temp_desc); + if (reset_buffer_a) { + tmp.x -= (Smem_tile_a::BUFFERS_PER_TILE - 1) * Smem_tile_a::BYTES_PER_BUFFER_NO_4LSB; + } else { + tmp.x += Smem_tile_a::BYTES_PER_BUFFER_NO_4LSB; + } + + gmma_desc_a_[mma_m_idx].set_descriptor(idx, temp_desc); + } + +#pragma unroll + for (int mma_n_idx = 0; mma_n_idx < MMAS_N; ++mma_n_idx) { + uint64_t temp_desc = gmma_desc_b_[mma_n_idx].get_descriptor(idx); + int2& tmp = reinterpret_cast(temp_desc); + if (reset_buffer_b) { + tmp.x -= (Smem_tile_b::BUFFERS_PER_TILE - 1) * Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB; + } else { + tmp.x += Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB; + } + gmma_desc_b_[mma_n_idx].set_descriptor(idx, temp_desc); + } + } + } + + // smarter way of increment a group of gmma desc. + // if one of them need to be reset to the first ldgsts buffer + // it is very likely (currently guaranteed) that all of them need to be reset to the first + // ldgsts buffer. + // we do this to save the usage of uniform register. 
Otherwise, kernel with larger M could not + // achieve sol. + inline __device__ void increment_gmma_desc_a_group() { + bool reset_buffer = gmma_desc_a_[0].get_descriptor(0) >= gmma_desc_a_[0].get_max_descriptor_0(); + +#pragma unroll + for (int idx = 0; idx < Smem_tile_b::Gmma_descriptor::NUM_DESCRIPTORS; ++idx) { +#pragma unroll + for (int mma_m_idx = 0; mma_m_idx < MMAS_M; ++mma_m_idx) { + uint64_t temp_desc = gmma_desc_a_[mma_m_idx].get_descriptor(idx); + // smem start address is in lower 32bits + int2& tmp = reinterpret_cast(temp_desc); + if (reset_buffer) { + tmp.x -= (Smem_tile_a::BUFFERS_PER_TILE - 1) * Smem_tile_a::BYTES_PER_BUFFER_NO_4LSB; + } else { + tmp.x += Smem_tile_a::BYTES_PER_BUFFER_NO_4LSB; + } + gmma_desc_a_[mma_m_idx].set_descriptor(idx, temp_desc); + } + } + } + + // smarter way of increment a group of gmma desc. + // if one of them need to be reset to the first ldgsts buffer + // it is very likely (currently guaranteed) that all of them need to be reset to the first + // ldgsts buffer. + // we do this to save the usage of uniform register. Otherwise, kernel with larger M could not + // achieve sol. + template + inline __device__ void increment_gmma_desc_b_group(int N = 1) { + bool reset_buffer = + RESET_CHECK && gmma_desc_b_[0].get_descriptor(0) >= gmma_desc_b_[0].get_max_descriptor_0(); + +#pragma unroll + for (int idx = 0; idx < Smem_tile_b::Gmma_descriptor::NUM_DESCRIPTORS; ++idx) { +#pragma unroll + for (int mma_n_idx = 0; mma_n_idx < MMAS_N; ++mma_n_idx) { + uint64_t temp_desc = gmma_desc_b_[mma_n_idx].get_descriptor(idx); + int2& tmp = reinterpret_cast(temp_desc); + if (reset_buffer) { + tmp.x -= (Smem_tile_b::BUFFERS_PER_TILE - 1) * Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB; + } else { + tmp.x += Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB; + } + gmma_desc_b_[mma_n_idx].set_descriptor(idx, temp_desc); + } + } + } + + // Compute. + // last of group indicates it is the last GMMA with a GMMA group. So the GSB should be updated + // last of kblock indicates it is the last GMMA with kblock. so desc will be updated accordingly + inline __device__ void compute(int ki, bool last_of_group = false, bool last_of_kblock = false) { +#pragma unroll + for (int mmas_m_idx = 0; mmas_m_idx < MMAS_M; ++mmas_m_idx) { +#pragma unroll + for (int mmas_n_idx = 0; mmas_n_idx < MMAS_N; ++mmas_n_idx) { + // weird code to use SEL to avoid reg spill + typename Smem_tile_a::Gmma_descriptor::Single_desc single_desc_a; + typename Smem_tile_b::Gmma_descriptor::Single_desc single_desc_b; + + single_desc_a.set(gmma_desc_a_[mmas_m_idx].get_descriptor(ki)); + single_desc_b.set(gmma_desc_b_[mmas_n_idx].get_descriptor(ki)); + + if (mmas_n_idx == (MMAS_N - 1)) { + // update desc for A + gmma_desc_a_[mmas_m_idx].increment_single_descriptor(last_of_kblock); + } + if (mmas_m_idx == (MMAS_M - 1)) { + // update desc for B + gmma_desc_b_[mmas_n_idx].increment_single_descriptor(last_of_kblock); + } + + if ((last_of_group == true) && (mmas_m_idx == (MMAS_M - 1)) && + (mmas_n_idx == (MMAS_N - 1))) { + // increment the scoreboard + acc_[mmas_m_idx][mmas_n_idx].template mma(single_desc_a, single_desc_b); + } else { + acc_[mmas_m_idx][mmas_n_idx].template mma(single_desc_a, single_desc_b); + } + } // for (mmas_n_idx) + } // for (mmas_m_idx) + } + + // Load from shared memory. For GMMA where both operand comes from SMEM, this does nothing + inline __device__ void load(Smem_tile_a& smem_a, Smem_tile_b& smem_b, int ki, + bool first = false) {} + + // The accumulators. 
+ Fragment_accumulator acc_[MMAS_M][MMAS_N]; + + // one descriptor group per stage, different GMMAs may or maynot share descriptor group + // each descriptor group holds all the descriptors for the entire kblock + + // The descriptor to load A. + typename Smem_tile_a::Gmma_descriptor gmma_desc_a_[MMAS_M]; + // The descriptor to load B. + typename Smem_tile_b::Gmma_descriptor gmma_desc_b_[MMAS_N]; + uint32_t a_smem_base_, b_smem_base_; +}; + +/* +compute tile used when A is from RF, B is from SMEM +*/ +template +struct Compute_tile_with_gmma { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The fragment for holding A. + using Fragment = Fragment_a; + + // static_assert(Cta_tile::K == 128); + // static_assert(Mma_tile::K_PER_MMA_PER_CTA == 64 ); + // pstatic_assert(NUM_KBLOCKS == 384 / 64); + static constexpr int NUM_KBLOCKS = Smem_tile_b::BUFFERS_PER_TILE / Cta_tile::WARPS_K; + // static_assert(NUM_KBLOCKS * Cta_tile::WARPS_K == Smem_tile_b::BUFFERS_PER_TILE); + + // desc for A and B should have the same strategy + static_assert(Smem_tile_a::Gmma_descriptor::GMMA_DESC_SIZE_PER_GROUP == + Smem_tile_b::Gmma_descriptor::GMMA_DESC_SIZE_PER_GROUP, + "GMMA desc for A and B should have the same strategy."); + + // The number of MMAs. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // TODO + enum { MMAS_K = Mma_tile::MMAS_K * Cta_tile::WARPS_K }; + + // Ctor. + inline __device__ Compute_tile_with_gmma() {} + + // Ctor, that helps set the gmma descs + inline __device__ Compute_tile_with_gmma(void* a_smem_, void* b_smem_) + : Compute_tile_with_gmma(__nvvm_get_smem_pointer(a_smem_), __nvvm_get_smem_pointer(b_smem_)) { + } + + inline __device__ Compute_tile_with_gmma(uint32_t, uint32_t b_smem_base) + : b_smem_base_(b_smem_base) { + // We always start at buffer 0 and take the number of buffers from the Smem_tile, as above. + uint32_t b_smem = b_smem_base_; +// do not need to set desc for matrix A +#pragma unroll + for (int mma_n_idx = 0; mma_n_idx < MMAS_N; ++mma_n_idx) { + gmma_desc_b_[mma_n_idx].set_smem_pointer(b_smem + + mma_n_idx * Smem_tile_b::GMMA_GROUP_SMEM_DISTANCE); + gmma_desc_b_[mma_n_idx].set_max_descriptor_0(Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB * + (Smem_tile_b::BUFFERS_PER_TILE - 1)); + } + } + + // move the gmme desc by N buffers. + // Something nice to have if we have persistent kernels. + inline __device__ void increment_N_gmma_desc_group(int N) { +#pragma unroll + for (int idx = 0; idx < Smem_tile_b::Gmma_descriptor::NUM_DESCRIPTORS; ++idx) { +#pragma unroll + for (int mma_n_idx = 0; mma_n_idx < MMAS_N; ++mma_n_idx) { + uint64_t temp_desc = gmma_desc_b_[mma_n_idx].get_descriptor(idx); + int2& tmp = reinterpret_cast(temp_desc); + tmp.x = + (tmp.x & 0xFFFF0000) + (b_smem_base_ / 16) + (N)*Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB; + gmma_desc_b_[mma_n_idx].set_descriptor(idx, temp_desc); + } + } + } + + // Clear the accumulators. It does nothing as we have a special flag for GMMA. + inline __device__ void clear() { fmha::clear(acc_); } + + // smarter way of increment a group of gmma desc. + // if one of them need to be reset to the first ldgsts buffer + // it is very likely (currently guaranteed) that all of them need to be reset to the first + // ldgsts buffer. + // we do this to save the usage of uniform register. Otherwise, kernel with larger M could not + // achieve sol. 
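+  // Illustrative walk-through (buffer count assumed for the example): with
+  // BUFFERS_PER_TILE == 3, two consecutive calls advance the B descriptors by one buffer each
+  // (buffer 0 -> 1 -> 2); on the third call the descriptor has reached max_descriptor_0, so
+  // 2 * BYTES_PER_BUFFER_NO_4LSB is subtracted and the group points at buffer 0 again,
+  // without spending a uniform register on an explicit buffer index.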
+ + template + inline __device__ void increment_gmma_desc_group(int N = 1) { + bool reset_buffer = + RESET_CHECK && gmma_desc_b_[0].get_descriptor(0) >= gmma_desc_b_[0].get_max_descriptor_0(); + +#pragma unroll + for (int idx = 0; idx < Smem_tile_b::Gmma_descriptor::NUM_DESCRIPTORS; ++idx) { +#pragma unroll + for (int mma_n_idx = 0; mma_n_idx < MMAS_N; ++mma_n_idx) { + uint64_t temp_desc = gmma_desc_b_[mma_n_idx].get_descriptor(idx); + int2& tmp = reinterpret_cast(temp_desc); + if (reset_buffer) { + tmp.x -= (Smem_tile_b::BUFFERS_PER_TILE - 1) * Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB; + } else { + tmp.x += Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB; + } + gmma_desc_b_[mma_n_idx].set_descriptor(idx, temp_desc); + } + } + } + + // Compute. + // last of group indicates it is the last GMMA with a GMMA group. So the GSB should be updated + // last of kblock indicates it is the last GMMA with kblock. so desc will be updated accordingly + inline __device__ void compute(int ki, bool last_of_group = false, bool last_of_kblock = false) { +#pragma unroll + for (int mmas_m_idx = 0; mmas_m_idx < MMAS_M; ++mmas_m_idx) { +#pragma unroll + for (int mmas_n_idx = 0; mmas_n_idx < MMAS_N; ++mmas_n_idx) { + // weird code to use SEL to avoid reg spill + typename Smem_tile_b::Gmma_descriptor::Single_desc single_desc_b; + + single_desc_b.set(gmma_desc_b_[mmas_n_idx].get_descriptor(ki)); + + if (mmas_m_idx == (MMAS_M - 1)) { + // update desc for B + gmma_desc_b_[mmas_n_idx].increment_single_descriptor(last_of_kblock); + } + + if ((last_of_group == true) && (mmas_m_idx == (MMAS_M - 1)) && + (mmas_n_idx == (MMAS_N - 1))) { + // increment the scoreboard + acc_[mmas_m_idx][mmas_n_idx].template mma(a_[mmas_m_idx], single_desc_b); + } else { + acc_[mmas_m_idx][mmas_n_idx].template mma(a_[mmas_m_idx], single_desc_b); + } + } // for (mmas_n_idx) + } // for (mmas_m_idx) + } + + template + inline __device__ void compute_incta_splitk(Fragment const (&frag_a)[K][1], int const warp_k) { + if (Smem_tile_b::Gmma_descriptor::TRANS_MODE == Gmma_descriptor_transpose::NOTRANS) { + // In this case, the K dimension is the leading dimension, so we need to set the smem + // locations correctly for each Warp in K. + + // The number of elements in K per group. + constexpr int ELTS_PER_KGROUP = Smem_tile_b::BYTES_PER_ROW / sizeof(typename Traits::B_type); + // The number of MMAS to perform before incrementing by the group stride. + constexpr int MMAS_K_PER_GROUP = ELTS_PER_KGROUP / Traits::GMMA_K; + // The number of MMAS a k-warp performs. + constexpr int MMAS_K_PER_WARP = Mma_tile::MMAS_K; + + int const group_offset = warp_k * MMAS_K_PER_WARP; + // Initialize the descriptor + int gi = group_offset / MMAS_K_PER_GROUP; + int ii = group_offset % MMAS_K_PER_GROUP; + + int BYTES_OFFSET_NO_4LSB = gi * Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB + + ii * Smem_tile_b::Gmma_descriptor::BYTES_PER_DESC_NO_4LSB; + + uint64_t desc_b = gmma_desc_b_[0].get_descriptor(0); + int2& desc_b_view = reinterpret_cast(desc_b); + desc_b_view.x += BYTES_OFFSET_NO_4LSB; + + typename Smem_tile_b::Gmma_descriptor::Single_desc single_desc_b; + single_desc_b.set(desc_b); +#pragma unroll + for (int ki = 0; ki < MMAS_K_PER_WARP - 1; ki++) { + acc_[0][0].template mma(frag_a[ki][0], single_desc_b); + + // Increment the descriptor for the next kblock. + int const ki_next = group_offset + ki + 1; + // Update descriptor for next GMMA. 
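+        // Illustration (group size assumed): with MMAS_K_PER_GROUP == 4, a ki_next that is not
+        // a multiple of 4 advances by one descriptor inside the current k-group, while a
+        // multiple of 4 crosses a k-group boundary and the address jumps to the start of the
+        // next buffer instead.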
+ if (ki_next % MMAS_K_PER_GROUP == 0) { + desc_b_view.x += Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB - + Smem_tile_b::Gmma_descriptor::BYTES_DESC_INC_BOUNDARY_NO_4LSB; + } else { + desc_b_view.x += Smem_tile_b::Gmma_descriptor::BYTES_PER_DESC_NO_4LSB; + } + single_desc_b.set(desc_b); + } + // Last one increments gsb. + acc_[0][0].template mma(frag_a[MMAS_K_PER_WARP - 1][0], single_desc_b); + } else { // GMMA supports transposed input: we can just advance SMEM address to the k-th block + // for each Warp in K. + + constexpr int NUM_KGROUPS = Smem_tile_b::BUFFERS_PER_TILE; + constexpr int MMAS_K_PER_GROUP = Mma_tile::MMAS_K / NUM_KGROUPS; + static_assert(MMAS_K_PER_GROUP * NUM_KGROUPS == Mma_tile::MMAS_K); + + uint64_t temp_desc = gmma_desc_b_[0].get_descriptor(0); + int2& tmp = reinterpret_cast(temp_desc); + + constexpr int BYTES_PER_K_GROUP_NO_4LSB = + Mma_tile::K_PER_WARP_GROUP * Mma_tile::N_PER_WARP_GROUP * sizeof(Traits::B_type) / 16; + tmp.x += warp_k * BYTES_PER_K_GROUP_NO_4LSB; + gmma_desc_b_[0].set_descriptor(0, temp_desc); + +#pragma unroll + for (int kbi = 0; kbi < NUM_KGROUPS - 1; kbi++) { +#pragma unroll + for (int ki = 0; ki < MMAS_K_PER_GROUP; ki++) { + fill_frag_a(frag_a[kbi * MMAS_K_PER_GROUP + ki][0]); + // Never increment scoreboard, but check for last kblock. + compute(ki, false, ki == MMAS_K_PER_GROUP - 1); + } + increment_gmma_desc_group(); + } + +#pragma unroll + for (int ki = 0; ki < MMAS_K_PER_GROUP - 1; ki++) { + fill_frag_a(frag_a[(NUM_KGROUPS - 1) * MMAS_K_PER_GROUP + ki][0]); + compute(ki); + } + + fill_frag_a(frag_a[NUM_KGROUPS * MMAS_K_PER_GROUP - 1][0]); + compute(NUM_KGROUPS * MMAS_K_PER_GROUP - 1, true, true); + } + } + + // Fill the input fragment + inline __device__ void fill_frag_a(Fragment a_temp) { +#pragma unroll + for (int idx = 0; idx < Fragment::NUM_REGS; ++idx) { + a_[0].reg(idx) = a_temp.reg(idx); + } + } + + // Load from shared memory. + // we don't actually need this with MHA fused kernel. + inline __device__ void load(Smem_tile_a& smem_a, Smem_tile_b& smem_b, int ki) { + // smem_a.load( a_[ki], ki ); + } + + // The accumulators. + Fragment_accumulator acc_[MMAS_M][MMAS_N]; + + // The fragments to load A. + // Need to think about is is better to declare as Fragment a_? + // for the second GEMM, MMAS_M is most likely 1. (at least for now. ) + Fragment a_[MMAS_M]; + + // one descriptor group per stage, different GMMAs may or maynot share descriptor group + // each descriptor group holds all the descriptors for the entire kblock + + // The descriptor to load B. + typename Smem_tile_b::Gmma_descriptor gmma_desc_b_[MMAS_N]; + uint32_t b_smem_base_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/fragment.h b/csrc/fmha_v2/fmha/hopper/fragment.h new file mode 100644 index 0000000000..0ee3c7e5be --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/fragment.h @@ -0,0 +1,491 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. 
Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// F R A G M E N T (A) +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Only needed if Operand A is coming from RF. +template +struct Fragment_a, Layout> + : public Fragment { + // A should be coming from RF. + static_assert(A_RF, "A_RF must be true to allocate RF for Operand A.\n"); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Only needed if Operand A is coming from RF. +template +struct Fragment_a, Layout> + : public Fragment { + // A should be coming from RF. + static_assert(A_RF, "A_RF must be true to allocate RF for Operand A.\n"); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Only needed if Operand A is coming from RF. +template +struct Fragment_a, Layout> + : public Fragment { + // A should be coming from RF. + static_assert(GMMA_A_RF == true, "GMMA_A_RF must be true to allocate RF for Operand A.\n"); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Only needed if Operand A is coming from RF. +template +struct Fragment_a, + Layout> + : public Fragment { + // A should be coming from RF. + static_assert(GMMA_A_RF == true, "GMMA_A_RF must be true to allocate RF for Operand A.\n"); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a, + Layout> + // TODO: Do we need the * 4 or not? + : public Fragment { + static_assert(sizeof(Input_type_A) == 1); + static_assert(sizeof(Input_type_B) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// H G M M A . F 1 6 +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// both operands are coming from SMEM + +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // Add two fragments. + template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_REGS; ++ii) { + this->reg(ii) = hadd2(this->reg(ii), other.reg(ii)); + } + } + + // Do the GMMA. + template + inline __device__ void mma(Gmma_single_desc_a const& single_desc_a, + Gmma_single_desc_b const& single_desc_b) { + // call hgmma + fmha::hgmma_fp16< + Gmma_single_desc_a::TRANS_MODE == fmha::Gmma_descriptor_transpose::TRANS ? true : false, + Gmma_single_desc_b::TRANS_MODE == fmha::Gmma_descriptor_transpose::TRANS ? true : false, + GMMA_N, INCREMENT_SCORE_BOARD>(single_desc_a.get(), single_desc_b.get(), this->regs_); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// both operands are coming from SMEM + +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // Add two fragments. 
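+  // bf16 GMMA accumulates in fp32, so unlike the fp16 specialization above (which adds packed
+  // half2 registers with hadd2) the sum here is formed element-wise on float values.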
+ template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_REGS; ++ii) { + this->elt(ii) = this->elt(ii) + other.elt(ii); + } + } + + // Do the GMMA. + template + inline __device__ void mma(Gmma_single_desc_a const& single_desc_a, + Gmma_single_desc_b const& single_desc_b) { + // call hgmma + fmha::hgmma_bf16< + Gmma_single_desc_a::TRANS_MODE == fmha::Gmma_descriptor_transpose::TRANS ? true : false, + Gmma_single_desc_b::TRANS_MODE == fmha::Gmma_descriptor_transpose::TRANS ? true : false, + GMMA_N, INCREMENT_SCORE_BOARD>(single_desc_a.get(), single_desc_b.get(), this->regs_); + } +}; + +////////////////////////////////////////////////////////////////////////////////////////////////// +// A is coming from RF; B is coming from SMEM + +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // The Traits + using Traits = Hopper_hgmma_fp16_traits; + + // Add two fragments. + template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_REGS; ++ii) { + this->reg(ii) = hadd2(this->reg(ii), other.reg(ii)); + } + } + + // Do the GMMA. + template + inline __device__ void mma(Fragment_a const& a, + Gmma_single_desc_b const& single_desc_b) { + // call hgmma + fmha::hgmma_rfa_fp16< + Gmma_single_desc_b::TRANS_MODE == fmha::Gmma_descriptor_transpose::TRANS ? true : false, + GMMA_N, INCREMENT_SCORE_BOARD>(a.regs_, single_desc_b.get(), this->regs_); + } +}; + +////////////////////////////////////////////////////////////////////////////////////////////////// +// A is coming from RF; B is coming from SMEM + +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // The Traits + using Traits = Hopper_hgmma_bf16_traits; + + // Add two fragments. + template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) = this->elt(ii) + other.elt(ii); + } + } + + // Do the GMMA. + template + inline __device__ void mma(Fragment_a const& a, + Gmma_single_desc_b const& single_desc_b) { + // call hgmma + fmha::hgmma_rfa_bf16< + Gmma_single_desc_b::TRANS_MODE == fmha::Gmma_descriptor_transpose::TRANS ? true : false, + GMMA_N, INCREMENT_SCORE_BOARD>(a.regs_, single_desc_b.get(), this->regs_); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// H G M M A . F 3 2 +// +////////////////////////////////////////////////////////////////////////////////////////////////// +// both operands are coming from SMEM +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // Add two fragments. + template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) = this->elt(ii) + other.elt(ii); + } + } + + // Do the GMMA. + template + inline __device__ void mma(Gmma_single_desc_a const& single_desc_a, + Gmma_single_desc_b const& single_desc_b) { + // call hgmma + fmha::hgmma_fp32< + Gmma_single_desc_a::TRANS_MODE == fmha::Gmma_descriptor_transpose::TRANS ? true : false, + Gmma_single_desc_b::TRANS_MODE == fmha::Gmma_descriptor_transpose::TRANS ? 
true : false, + GMMA_N, INCREMENT_SCORE_BOARD>(single_desc_a.get(), single_desc_b.get(), this->regs_); + } +}; + +// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// A is coming from RF; B is coming from SMEM +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // The Traits + using Traits = Hopper_hgmma_fp32_traits; + + // Add two fragments. + template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) = this->elt(ii) + other.elt(ii); + } + } + + // Do the GMMA. + template + inline __device__ void mma(Fragment_a const& a, + Gmma_single_desc_b const& single_desc_b) { + // call hgmma + fmha::hgmma_rfa_fp32< + Gmma_single_desc_b::TRANS_MODE == fmha::Gmma_descriptor_transpose::TRANS ? true : false, + GMMA_N, INCREMENT_SCORE_BOARD>(a.regs_, single_desc_b.get(), this->regs_); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Q G M M A . F 3 2 +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// I G M M A . I N T 8 +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Both operands are coming from SMEM. +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // Do the GMMA. + template + inline __device__ void mma(Gmma_single_desc_a const& single_desc_a, + Gmma_single_desc_b const& single_desc_b) { + fmha::igmma_int8_int32(single_desc_a.get(), single_desc_b.get(), + this->regs_); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// A is coming from RF; B is coming from SMEM + +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // The Traits. + using Traits = Hopper_igmma_int8_int32_traits; + + // Do the GMMA. + template + inline __device__ void mma(Fragment_a const& a, + Gmma_single_desc_b const& single_desc_b) { + fmha::igmma_rfa_int8_int32(a.regs_, single_desc_b.get(), + this->regs_); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Fp32 Accumulator A operand from RF and B operand from SMEM +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // Add two fragments. + template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) = this->elt(ii) + other.elt(ii); + } + } + + // The Traits + using Traits = Hopper_qgmma_fp8_fp32_traits; + + // Do the GMMA. 
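+  // The fp8 element types of A and B select the matching QGMMA flavour below
+  // (e4m3/e5m2 in either operand); any other combination asserts.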
+ template + inline __device__ void mma(Fragment_a const& a, + Gmma_single_desc_b const& single_desc_b) { + // call hgmma + if (std::is_same_v && std::is_same_v) { + qgmma_rfa_e4m3_e4m3_fp32(a.regs_, single_desc_b.get(), + this->regs_); + } else if (std::is_same_v && std::is_same_v) { + qgmma_rfa_e5m2_e4m3_fp32(a.regs_, single_desc_b.get(), + this->regs_); + } else if (std::is_same_v && std::is_same_v) { + qgmma_rfa_e4m3_e5m2_fp32(a.regs_, single_desc_b.get(), + this->regs_); + } else if (std::is_same_v && std::is_same_v) { + qgmma_rfa_e5m2_e5m2_fp32(a.regs_, single_desc_b.get(), + this->regs_); + } else { + assert(false && "unsupported"); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// fp32 accumulator +// Both operands are coming from SMEM. +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // Do the GMMA. + template + inline __device__ void mma(Gmma_single_desc_a const& single_desc_a, + Gmma_single_desc_b const& single_desc_b) { + if (std::is_same_v && std::is_same_v) { + qgmma_e4m3_e4m3_fp32(single_desc_a.get(), single_desc_b.get(), + this->regs_); + } else if (std::is_same_v && std::is_same_v) { + qgmma_e5m2_e4m3_fp32(single_desc_a.get(), single_desc_b.get(), + this->regs_); + } else if (std::is_same_v && std::is_same_v) { + qgmma_e4m3_e5m2_fp32(single_desc_a.get(), single_desc_b.get(), + this->regs_); + } else if (std::is_same_v && std::is_same_v) { + qgmma_e5m2_e5m2_fp32(single_desc_a.get(), single_desc_b.get(), + this->regs_); + } else { + assert(false && "unsupported"); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_saver_tma { + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // Ctor. + template + inline __device__ Softmax_saver_tma(Params const& params, Head_info const& head_info) + : actual_len_(head_info.actual_seqlen), + local_q_tile_offset_(head_info.local_q_tile_offset), + softmax_sum_ptr_(reinterpret_cast(params.softmax_stats_ptr)), + softmax_stats_stride_in_bytes_(params.softmax_stats_stride_in_bytes) { + softmax_max_ptr_ = reinterpret_cast(params.softmax_stats_ptr); + int warp = (threadIdx.x % 128) / Cta_tile::THREADS_PER_WARP; + int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; + // MMA row0 index (8x4 thread layout) + row0_ = warp * Mma_tile::M_PER_MMA / WARPS_M + (lane / 4); + + int sum_s = + params.is_s_padded ? params.s * head_info.bidb : params.cu_q_seqlens[head_info.bidb]; + int token_id = sum_s * params.h + head_info.bidh; + size_t const bh_offset = + token_id * sizeof(float) * 2 + local_q_tile_offset_ * softmax_stats_stride_in_bytes_; + softmax_max_ptr_ += bh_offset + row0_ * softmax_stats_stride_in_bytes_; + softmax_sum_ptr_ += bh_offset + row0_ * softmax_stats_stride_in_bytes_ + sizeof(float); + }; + + inline __device__ void store(float* p_sum, float* p_max, float sqrt_d, int row_offset, + bool valid_run) { + // Four threads process two rows in mma, each row has one softmax_sum and one softmax_max. + // Here we use one thread to write one softmax element. + float values; + int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; + if (lane % 4 < 2) { + values = p_sum[lane % 2]; + } else { + values = p_max[lane % 2] / sqrt_d; + } + if (!valid_run && (lane % 4) < 2) { + values = 1.0; + } + char* dst_ptr = (lane % 4 < 2) ? 
softmax_sum_ptr_ : softmax_max_ptr_; + size_t off_inside_mma = (lane % 2 == 0) ? row_offset : row_offset + 8; + if (local_q_tile_offset_ + row0_ + off_inside_mma < actual_len_) { + fmha::stg(dst_ptr + off_inside_mma * softmax_stats_stride_in_bytes_, values); + } + } + + // ptr + char* softmax_sum_ptr_ = nullptr; + char* softmax_max_ptr_ = nullptr; + + // the first row's idx + int row0_; + // actual seq length + int const actual_len_; + int const softmax_stats_stride_in_bytes_; + int const local_q_tile_offset_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h b/csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h new file mode 100644 index 0000000000..7c9ac43bb8 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h @@ -0,0 +1,1138 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include +#include + +namespace fmha { + +namespace v2 { + +template +struct Gmem_tile_o_hopper {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Not super proud of this. Need to refactor. +// A not optimized way of storing tile_O, without SMEM swizzle. +// STG.32 is going to be used. +template +struct Gmem_tile_o_hopper_16bits { + // The associated MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of elements per STG. + enum { ELEMENTS_PER_STG = 2 }; + + // The size in bytes of each element. + enum { BYTES_PER_ELEMENT = 2 }; + + // The size of each STG. + enum { BYTES_PER_STG = ELEMENTS_PER_STG * BYTES_PER_ELEMENT }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::VALID_N * BYTES_PER_ELEMENT }; + + // The number of rows accessed by each thread. + enum { ROWS_PER_THREAD = Mma_tile::M_PER_MMA / 8 / Cta_tile::WARPS_PER_CTA }; + + enum { ROWS = Cta_tile::M }; + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLS_PER_THREAD = Mma_tile::N_PER_MMA / 4 / 2 }; + + // The number of valid columns (stored to GMEM) by each thread. + enum { + VALID_COLS_PER_THREAD_FOR_LAST_MMA = (Cta_tile::VALID_N % Mma_tile::N_PER_MMA) == 0 + ? COLS_PER_THREAD + : (Cta_tile::VALID_N % Mma_tile::N_PER_MMA) / 8 + }; + + enum { VALID_MMAS_N = fmha::Div_up::VALUE }; + + static_assert(Cta_tile::VALID_N % 8 == 0, "The valid head dimension needs to be multiple of 8."); + + // The number of accumulator held by each thread, per HGMMA instruction. + enum { ELTS_PER_THREAD = ROWS_PER_THREAD * COLS_PER_THREAD }; + + // Currently, we assume for o matrix, GMMA M/N shape matches CTA M/N shape. + static_assert(Mma_tile::M_PER_MMA == Cta_tile::M && + Mma_tile::N_PER_MMA * Mma_tile::MMAS_N == Cta_tile::N, + "Currently, we assume for o matrix, GMMA M shape matches CTA M shape. "); + + // Step N for one quad + enum { STEP_N = 8 * BYTES_PER_ELEMENT }; + + // Ctor. 
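+  // (Thread mapping used by the constructor below: within a warp the 8x4 quad layout places
+  // lane / 4 on the row and (lane % 4) * ELEMENTS_PER_STG on the column, so each thread
+  // stores 2 16-bit elements per STG.)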
+ template + inline __device__ Gmem_tile_o_hopper_16bits(Params const& params, Block_info const& block_info, + int tidx, int cta_row_offset = 0) + : params_o_stride_in_bytes_(params.o_stride_in_bytes), + actual_seqlen_(block_info.actual_seqlen), + o_ptr_(reinterpret_cast(params.o_ptr)) { + // Decompose the position of the thread into warp/lane. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // int warpgroup_idx = warp / 4; + int warp_idx_within_warpgroup = warp % 4; + + // Compute the position in the sequence (within the CTA for the moment). + int row = warp_idx_within_warpgroup * (Mma_tile::M_PER_MMA / 4) + lane / 4; + // Store the row to update the predicates in load. + row_ = cta_row_offset + row; + // Compute the position of the thread in the row. + int col = lane % 4 * ELEMENTS_PER_STG; + + // The offset of the 1st row written by the thread. We store the P matrix interleaved. + int64_t row_offset = + (int64_t)row_ * params_o_stride_in_bytes_ + block_info.bidx * BYTES_PER_ROW; + // Finalize the pointer. + o_ptr_ += row_offset + col * BYTES_PER_ELEMENT; + } + + // Store data to memory. + template + inline __device__ void store(Accumulators const (&acc)[M][N]) { + int64_t const step_m = 8 * (this->params_o_stride_in_bytes_); + // we assume M = 1. some shortcuts. + static_assert(M == 1); +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + if (row_ + row_idx * 8 >= actual_seqlen_) { + break; + } +#pragma unroll + for (int mma_ni = 0; mma_ni < VALID_MMAS_N - 1; ++mma_ni) { +#pragma unroll + for (int col_idx = 0; col_idx < COLS_PER_THREAD; ++col_idx) { + uint32_t acc_0 = acc[0][mma_ni].reg(col_idx * ROWS_PER_THREAD + row_idx); + + int64_t offset = + (int64_t)row_idx * step_m + (int64_t)(col_idx + mma_ni * COLS_PER_THREAD) * STEP_N; + fmha::stg(o_ptr_ + offset, acc_0); + } // col_idx + } // mma_ni + + // The last mma_n may not store full elements back to GMEM. + int mma_ni = VALID_MMAS_N - 1; +#pragma unroll + for (int col_idx = 0; col_idx < VALID_COLS_PER_THREAD_FOR_LAST_MMA; ++col_idx) { + uint32_t acc_0 = acc[0][mma_ni].reg(col_idx * ROWS_PER_THREAD + row_idx); + + int64_t offset = + (int64_t)row_idx * step_m + (int64_t)(col_idx + mma_ni * COLS_PER_THREAD) * STEP_N; + fmha::stg(o_ptr_ + offset, acc_0); + } // col_idx + } // row_idx + } + + // Move to the next location. + inline __device__ void move() { + row_ += ROWS; + o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_; + } + + // The stride between rows for the QKV matrice. + int64_t params_o_stride_in_bytes_; + // The pointer. + char* o_ptr_; + // Is the thread active for the last STG? + int is_active_for_last_stg_; + + // The row loaded by this thread. + int row_; + // The length of the sequence loaded by that CTA. 
+ int actual_seqlen_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_hopper< + fmha::Hopper_hgmma_fp16_traits, Cta_tile, + 1> // WARPS_K + : public Gmem_tile_o_hopper_16bits< + fmha::Hopper_hgmma_fp16_traits, Cta_tile> { + using Traits = fmha::Hopper_hgmma_fp16_traits; + + using Base = Gmem_tile_o_hopper_16bits< + fmha::Hopper_hgmma_fp16_traits, Cta_tile>; + + template + inline __device__ Gmem_tile_o_hopper(Params const& params, Block_info const& block_info, Shared&&, + int tidx, int cta_row_offset = 0) + : Base(params, block_info, tidx, cta_row_offset) { + static_assert(!std::is_same::value, "Check constructor argument type!"); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_hopper< + fmha::Hopper_hgmma_fp32_traits, Cta_tile, + 1> // WARPS_K + : public Gmem_tile_o_hopper_16bits< + fmha::Hopper_hgmma_fp32_traits, Cta_tile> { + using Traits = fmha::Hopper_hgmma_fp32_traits; + + using Base = Gmem_tile_o_hopper_16bits< + fmha::Hopper_hgmma_fp32_traits, Cta_tile>; + + using Mma_tile = typename Base::Mma_tile; + + template + inline __device__ Gmem_tile_o_hopper(Params const& params, Block_info const& block_info, Shared&&, + int tidx, int cta_row_offset = 0) + : Base(params, block_info, tidx, cta_row_offset) { + static_assert(!std::is_same::value, "Check constructor argument type!"); + } + + // Store data to memory. + template + inline __device__ void store(Accumulators const (&acc)[M][N]) { + int64_t const step_m = 8 * (this->params_o_stride_in_bytes_); + // we assume M = 1. some shortcuts. + static_assert(M == 1); +#pragma unroll + for (int row_idx = 0; row_idx < Base::ROWS_PER_THREAD; ++row_idx) { + if (this->row_ + row_idx * 8 >= this->actual_seqlen_) { + break; + } +#pragma unroll + for (int mma_ni = 0; mma_ni < Base::VALID_MMAS_N - 1; ++mma_ni) { +#pragma unroll + for (int col_idx = 0; col_idx < Base::COLS_PER_THREAD; ++col_idx) { + // 2 denotes as fp32 --> fp16 + float reg0 = acc[0][mma_ni].elt(2 * (col_idx * Base::ROWS_PER_THREAD + row_idx)); + float reg1 = acc[0][mma_ni].elt(2 * (col_idx * Base::ROWS_PER_THREAD + row_idx) + 1); + uint32_t out = fmha::float2_to_half2(reg0, reg1); + + int64_t offset = (int64_t)row_idx * step_m + + (int64_t)(col_idx + mma_ni * Base::COLS_PER_THREAD) * Base::STEP_N; + fmha::stg(this->o_ptr_ + offset, out); + } // col_idx + } // mma_ni + + // The last mma_n may not store full elements back to GMEM. 
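+      // e.g. (sizes assumed for illustration): VALID_N == 40 with N_PER_MMA == 64 gives
+      // VALID_MMAS_N == 1 and VALID_COLS_PER_THREAD_FOR_LAST_MMA == 40 / 8 == 5, so only 5 of
+      // the 8 column pairs are written for this last MMA.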
+ int mma_ni = Base::VALID_MMAS_N - 1; +#pragma unroll + for (int col_idx = 0; col_idx < Base::VALID_COLS_PER_THREAD_FOR_LAST_MMA; ++col_idx) { + // 2 denotes as fp32 --> fp16 + float reg0 = acc[0][mma_ni].elt(2 * (col_idx * Base::ROWS_PER_THREAD + row_idx)); + float reg1 = acc[0][mma_ni].elt(2 * (col_idx * Base::ROWS_PER_THREAD + row_idx) + 1); + uint32_t out = fmha::float2_to_half2(reg0, reg1); + + int64_t offset = (int64_t)row_idx * step_m + + (int64_t)(col_idx + mma_ni * Base::COLS_PER_THREAD) * Base::STEP_N; + fmha::stg(this->o_ptr_ + offset, out); + } // col_idx + } // row_idx + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_hopper< + fmha::Hopper_hgmma_bf16_traits, Cta_tile, + 1> // WARPS_K + : public Gmem_tile_o_hopper_16bits< + fmha::Hopper_hgmma_bf16_traits, Cta_tile> { + using Traits = fmha::Hopper_hgmma_bf16_traits; + + using Base = Gmem_tile_o_hopper_16bits< + fmha::Hopper_hgmma_bf16_traits, Cta_tile>; + + using Mma_tile = typename Base::Mma_tile; + + template + inline __device__ Gmem_tile_o_hopper(Params const& params, Block_info const& block_info, Shared&&, + int tidx, int cta_row_offset = 0) + : Base(params, block_info, tidx, cta_row_offset) { + static_assert(!std::is_same::value, "Check constructor argument type!"); + } + + // Store data to memory. + template + inline __device__ void store(Accumulators const (&acc)[M][N]) { + int64_t const step_m = 8 * (this->params_o_stride_in_bytes_); + // we assume M = 1. some shortcuts. + static_assert(M == 1); +#pragma unroll + for (int row_idx = 0; row_idx < Base::ROWS_PER_THREAD; ++row_idx) { + if (this->row_ + row_idx * 8 >= this->actual_seqlen_) { + break; + } +#pragma unroll + for (int mma_ni = 0; mma_ni < Mma_tile::VALID_MMAS_N - 1; ++mma_ni) { +#pragma unroll + for (int col_idx = 0; col_idx < Base::COLS_PER_THREAD; ++col_idx) { + // 2 denotes as fp32 --> bf16 + float reg0 = acc[0][mma_ni].elt(2 * (col_idx * Base::ROWS_PER_THREAD + row_idx)); + float reg1 = acc[0][mma_ni].elt(2 * (col_idx * Base::ROWS_PER_THREAD + row_idx) + 1); + uint32_t out = fmha::float2_to_bf16_x2(reg0, reg1); + + int64_t offset = (int64_t)row_idx * step_m + + (int64_t)(col_idx + mma_ni * Base::COLS_PER_THREAD) * Base::STEP_N; + fmha::stg(this->o_ptr_ + offset, out); + } // row_idx + } // col_idx + + // The last mma_n may not store full elements back to GMEM. 
+ int mma_ni = Base::VALID_MMAS_N - 1; +#pragma unroll + for (int col_idx = 0; col_idx < Base::VALID_COLS_PER_THREAD_FOR_LAST_MMA; ++col_idx) { + // 2 denotes as fp32 --> bf16 + float reg0 = acc[0][mma_ni].elt(2 * (col_idx * Base::ROWS_PER_THREAD + row_idx)); + float reg1 = acc[0][mma_ni].elt(2 * (col_idx * Base::ROWS_PER_THREAD + row_idx) + 1); + uint32_t out = fmha::float2_to_bf16_x2(reg0, reg1); + + int64_t offset = (int64_t)row_idx * step_m + + (int64_t)(col_idx + mma_ni * Base::COLS_PER_THREAD) * Base::STEP_N; + fmha::stg(this->o_ptr_ + offset, out); + } // row_idx + } // mma_ni + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_hopper< + fmha::Hopper_hgmma_fp16_traits, Cta_tile, + 2> // WARPS_K + : public fmha::v2::Hmma_gmem_tile_o< + fmha::Hopper_hgmma_fp16_traits, Cta_tile, + /*CTAS_PER_HEAD=*/1, + /*BYTES_PER_STG=*/16> { + using Traits = fmha::Hopper_hgmma_fp16_traits; + using Base = fmha::v2::Hmma_gmem_tile_o; + + template + inline __device__ Gmem_tile_o_hopper(Params const& params, Block_info const& block_info, Shared&&, + int tidx, int cta_row_offset = 0) + : Base(params, block_info, tidx, cta_row_offset) { + static_assert(!std::is_same::value, "Check constructor argument type!"); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_hopper< + fmha::Hopper_hgmma_fp32_traits, Cta_tile, + 2> // WARPS_K + : public fmha::v2::Hmma_gmem_tile_o< + fmha::Hopper_hgmma_fp32_traits, Cta_tile, + /*CTAS_PER_HEAD=*/1, + /*BYTES_PER_STG=*/16> { + using Traits = fmha::Hopper_hgmma_fp32_traits; + using Base = fmha::v2::Hmma_gmem_tile_o; + + template + inline __device__ Gmem_tile_o_hopper(Params const& params, Block_info const& block_info, Shared&&, + int tidx, int cta_row_offset = 0) + : Base(params, block_info, tidx, cta_row_offset) { + static_assert(!std::is_same::value, "Check constructor argument type!"); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_hopper< + fmha::Hopper_hgmma_bf16_traits, Cta_tile, + 2> // WARPS_K + : public fmha::v2::Hmma_gmem_tile_o< + fmha::Hopper_hgmma_bf16_traits, Cta_tile, + /*CTAS_PER_HEAD=*/1, + /*BYTES_PER_STG=*/16> { + using Traits = fmha::Hopper_hgmma_bf16_traits; + using Base = fmha::v2::Hmma_gmem_tile_o; + + template + inline __device__ Gmem_tile_o_hopper(Params const& params, Block_info const& block_info, Shared&&, + int tidx, int cta_row_offset = 0) + : Base(params, block_info, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o, + Cta_tile, CTAS_PER_HEAD> + : public Gmem_tile_o_hopper< + fmha::Hopper_hgmma_fp16_traits, Cta_tile, + Cta_tile::WARPS_K> { + // The traits class. + using Traits = fmha::Hopper_hgmma_fp16_traits; + + using Base = Gmem_tile_o_hopper< + fmha::Hopper_hgmma_fp16_traits, Cta_tile, + Cta_tile::WARPS_K>; + + // Ctor. 
+ template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0) + : Base(params, block_info, std::nullptr_t{} /* dummy obj */, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o, + Cta_tile, CTAS_PER_HEAD> + : public Gmem_tile_o_hopper< + fmha::Hopper_hgmma_fp32_traits, Cta_tile, + Cta_tile::WARPS_K> { + // The traits class. + using Traits = fmha::Hopper_hgmma_fp32_traits; + + using Base = Gmem_tile_o_hopper< + fmha::Hopper_hgmma_fp32_traits, Cta_tile, + Cta_tile::WARPS_K>; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0) + : Base(params, block_info, std::nullptr_t{} /* dummy obj */, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o, + Cta_tile, CTAS_PER_HEAD> + : public Gmem_tile_o_hopper< + fmha::Hopper_hgmma_bf16_traits, Cta_tile, + Cta_tile::WARPS_K> { + // The traits class. + using Traits = fmha::Hopper_hgmma_bf16_traits; + + using Base = Gmem_tile_o_hopper< + fmha::Hopper_hgmma_bf16_traits, Cta_tile, + Cta_tile::WARPS_K>; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0) + : Base(params, block_info, std::nullptr_t{} /* dummy obj */, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_gmma_32bit_8bit { + static_assert(sizeof(typename Traits::Accumulator_type) == 4); + static_assert(sizeof(typename Traits::C_type) == 1); + // This is for non-splitk GMMA BMM2. + static_assert(Cta_tile::WARPS_K == 1); + // The associated MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of elements per STG. + enum { ELEMENTS_PER_STG = 4 }; + + // The size in bytes of each element. + enum { BYTES_PER_ELEMENT = 1 }; + + // The size of each STG. + enum { BYTES_PER_STG = ELEMENTS_PER_STG * BYTES_PER_ELEMENT }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::VALID_N * BYTES_PER_ELEMENT }; + + enum { ROWS = Cta_tile::M }; + + // The number of rows accessed by each thread. + enum { ROWS_PER_THREAD = Mma_tile::M_PER_MMA / 8 / Cta_tile::WARPS_M }; + + static_assert(ROWS_PER_THREAD == 2); + static_assert(ROWS_PER_THREAD == Mma_tile::ROWS_PER_THREAD); + + // The number of columns access by each thread. + // The number of core matrices in N. + enum { COLS_PER_THREAD = Mma_tile::N_PER_MMA / 4 / 2 }; // N_PER_MMA = GMMA_N + + static_assert(COLS_PER_THREAD == Mma_tile::COLS_PER_THREAD / 2); + // Assume there is an even number of core matrices, such that we can pack two + static_assert(COLS_PER_THREAD % 2 == 0); + + // Number of valid N columns. + enum { VALID_N = Cta_tile::VALID_N }; + + // The number of valid columns (stored to GMEM) by each thread. + enum { + VALID_COLS_PER_THREAD_FOR_LAST_MMA = + (VALID_N % Mma_tile::N_PER_MMA) == 0 ? COLS_PER_THREAD : (VALID_N % Mma_tile::N_PER_MMA) / 8 + }; + + enum { VALID_MMAS_N = fmha::Div_up::VALUE }; + + static_assert(VALID_N % 8 == 0, "The valid head dimension needs to be multiple of 8."); + + // The number of N elements must be multiple of 16 in order to pack 4 elements as uint32_t. 
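+  // e.g. (head sizes assumed for illustration): VALID_N == 80 satisfies 80 % 16 == 0 and takes
+  // the packed uint32_t STG path in store(), while VALID_N == 72 does not and falls back to
+  // the 2-element (uint16_t) stores at the end of store().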
+ enum { PACK_4_ELTS = VALID_N % 16 == 0 }; + + // The number of accumulator held by each thread, per HGMMA instruction. + enum { ELTS_PER_THREAD = ROWS_PER_THREAD * COLS_PER_THREAD * 2 }; + + // Currently, we assume for o matrix, GMMA M shape matches CTA M shape. + static_assert(Mma_tile::M_PER_MMA == Cta_tile::M && + Mma_tile::N_PER_MMA * Mma_tile::MMAS_N == Cta_tile::N, + "Currently, we assume for o matrix, GMMA M/N shape matches CTA M/N shape. "); + + // Step N for one quad (pack 4 elements for a thread, so 16 elements for a quad) + enum { STEP_N = 16 * BYTES_PER_ELEMENT }; + + // The number of head_dimension groups. + enum { N_GROUPS = fmha::Div_up::VALUE }; + + // The head_dimension per group. + enum { N_PER_GROUP = Cta_tile::N / N_GROUPS }; + + static_assert(N_GROUPS * N_PER_GROUP == Cta_tile::N); + + // The head_dimension bytes per group + enum { N_BYTES_PER_GROUP = Cta_tile::N * BYTES_PER_ELEMENT / N_GROUPS }; + + // Pack 2x4 core matrices, use STSMx4 + enum { STSM_PER_MMA = COLS_PER_THREAD / 4 }; + + // The number of registers per 16x16 block + enum { REGS_PER_QUAD = 8 }; + + // Bytes per bank + enum { BYTES_PER_BANK = 16 }; + + // The number of banks in N per group + enum { N_BANKS_PER_GROUP = N_BYTES_PER_GROUP / BYTES_PER_BANK }; + + enum { USE_TMA_STORE = USE_TMA_STORE_ }; + + // Ctor. + template + inline __device__ Gmem_tile_o_gmma_32bit_8bit(Params const& params, Block_info const& block_info, + Shared& shared, int tidx, int cta_row_offset = 0) + : Gmem_tile_o_gmma_32bit_8bit( + params.o_ptr, params.o_stride_in_bytes, block_info, tidx, +#ifdef GENERATE_CUBIN + // Specialized for trt-llm generated cubins only. + params.scale_bmm2_d ? *params.scale_bmm2_d : params.scale_bmm2, +#else + params.scale_bmm2, +#endif + cta_row_offset, 0, + __nvvm_get_smem_pointer(reinterpret_cast( + &shared.smem_o[__shfl_sync(0xffffffff, threadIdx.x / 128, 0)][0])), + ¶ms.tma_desc_o, params.h) { + } + + template + inline __device__ Gmem_tile_o_gmma_32bit_8bit(void* o_ptr, int o_stride_in_bytes, + Block_info const& block_info, int tidx, + uint32_t scale_bmm2, int cta_row_offset = 0, + int mat_offset = 0, uint32_t smem_base = 0, + cudaTmaDesc const* desc_o = nullptr, + int head_num = 0) + : params_o_stride_in_bytes_(o_stride_in_bytes), + actual_seqlen_(block_info.actual_seqlen), + o_ptr_(reinterpret_cast(o_ptr)), + params_scale_bmm2_(scale_bmm2), + smem_base_(smem_base), + desc_o_(desc_o) { + // Decompose the position of the thread into warp/lane. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // int warpgroup_idx = warp / 4; + int warp_idx_within_warpgroup = warp % 4; + + if (USE_TMA_STORE) { + // The head index + bidh_ = block_info.bidh; + // The lane id + lane_ = lane; + // The start row index for current batch + int row_curr_batch = (block_info.bidx - block_info.bidh) / head_num; + // The row index offset of current warp + int row_offset_warp = cta_row_offset + warp_idx_within_warpgroup * (Mma_tile::M_PER_MMA / 4); + // The row index for the current warp + row_tma_ = row_offset_warp + row_curr_batch; + // The valid rows for the current warp. Each warp writes from 0 to 16 rows + num_valid_rows_ = min(Mma_tile::M_PER_MMA / 4, actual_seqlen_ - row_offset_warp); + num_valid_rows_ = max(num_valid_rows_, 0); + // WARNING: Without this line, the predicate will not behavior as expected for unknown reason. 
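+      // (One plausible explanation, not verified here: broadcasting from lane 0 makes
+      // num_valid_rows_ warp-uniform, which lets the compiler keep the TMA-store predicate in a
+      // uniform register.)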
+ num_valid_rows_ = __shfl_sync(0xffffffff, num_valid_rows_, 0); + // Compute the smem base for STSM + smem_base_ += + warp_idx_within_warpgroup * (Mma_tile::M_PER_MMA / 4) * Cta_tile::N * BYTES_PER_ELEMENT + + (warp / 4) * Mma_tile::M_PER_MMA * Cta_tile::N * BYTES_PER_ELEMENT; + // Compute gmem base for STG in tail case + o_ptr_ += row_tma_ * params_o_stride_in_bytes_ + bidh_ * BYTES_PER_ROW; + } else { + // Compute the position in the sequence (within the CTA for the moment). + int row = warp_idx_within_warpgroup * (Mma_tile::M_PER_MMA / 4) + lane / 4; + // Store the row to update the predicates in load. + row_ = cta_row_offset + row; + // Compute the position of the thread in the row. + col_ = lane % 4 * ELEMENTS_PER_STG; + + // The offset of the 1st row written by the thread. We store the P matrix interleaved. + int64_t row_offset = + (int64_t)row_ * params_o_stride_in_bytes_ + block_info.bidx * BYTES_PER_ROW; + // Finalize the pointer. + o_ptr_ += row_offset + col_ * BYTES_PER_ELEMENT; + } + + // REVIEW: need heads_interleaved option for non-warp-specialized QGMMA + LDGSTS kernels. + // // The row offset in the batched GEMM. For each seq element, we store QKV in that order. + // int64_t row_offset = (int64_t) row_ * params_o_stride_in_bytes_; + // // Add the block index. + // int64_t idx = block_info.bidx; + // if(NUM_MATS > 1) { + // if( HEADS_INTERLEAVED ) { + // idx = block_info.bidx * NUM_MATS + mat_offset; + // } else { + // idx = (block_info.sum_s * NUM_MATS + mat_offset) * block_info.num_heads + + // block_info.bidh; + // } + // } + // // Assemble the final pointer. + // o_ptr_ += row_offset + idx * BYTES_PER_ROW + col * BYTES_PER_ELEMENT; + } + + // Store data to memory. + template + inline __device__ void store(Accumulators const (&acc)[M][N]) { + static_assert(Accumulators::NUM_ELTS == ELTS_PER_THREAD); + static_assert(COLS_PER_THREAD / 2 * ROWS_PER_THREAD * 4 == ELTS_PER_THREAD); + + // we assume M = N = 1. some shortcuts. + static_assert(M == 1); + + if (USE_TMA_STORE) { + static_assert(COLS_PER_THREAD % 4 == 0); + static_assert(ROWS_PER_THREAD == 2); + + int const swizzled_row = (lane_ % 16); + int const swizzled_col = (lane_ / 16); + constexpr int max_swizzle_id = N_BYTES_PER_GROUP / 16; + constexpr int swizzle_row_divider = 128 / N_BYTES_PER_GROUP; + + uint32_t stsm_addr[VALID_MMAS_N][STSM_PER_MMA]; +// Compute swizzled smem address +#pragma unroll + for (int mma_ni = 0; mma_ni < VALID_MMAS_N; ++mma_ni) { +#pragma unroll + for (int ci = 0; ci < STSM_PER_MMA; ++ci) { + int const col_bank = ((mma_ni)*STSM_PER_MMA + ci) * 2 + swizzled_col; + int const di = col_bank / N_BANKS_PER_GROUP; // which N group it belongs to + stsm_addr[mma_ni][ci] = smem_base_ + di * 16 * N_BYTES_PER_GROUP + // group dimension + (((swizzled_row / swizzle_row_divider) % max_swizzle_id) ^ + (col_bank % N_BANKS_PER_GROUP)) * + BYTES_PER_BANK + // column dimension + swizzled_row * N_BYTES_PER_GROUP; // row dimension + } + } + +#pragma unroll + for (int mma_ni = 0; mma_ni < VALID_MMAS_N; ++mma_ni) { +#pragma unroll + for (int ci = 0; ci < STSM_PER_MMA; ++ci) { + uint32_t dst[4]; + uint4 src[4]; + + /* + * Each STSMx4 produces a 16x32 block, that is 2x4 core matrices + * ----------------- + * | 0 | 2 | 4 | 6 | + * ----------------- + * | 1 | 3 | 5 | 7 | + * ----------------- + * + * Consider the entire warp, src[0] holds matrices 0,2; src[1] holds matrices 1,3; + * src[3] holds matrices 4,6; src[4] holds matrices 5,7. 
+ */ + src[0].x = acc[0][mma_ni].reg((ci * 2 + 0) * REGS_PER_QUAD + 0); + src[0].y = acc[0][mma_ni].reg((ci * 2 + 0) * REGS_PER_QUAD + 4); + src[0].z = acc[0][mma_ni].reg((ci * 2 + 0) * REGS_PER_QUAD + 1); + src[0].w = acc[0][mma_ni].reg((ci * 2 + 0) * REGS_PER_QUAD + 5); + + src[1].x = acc[0][mma_ni].reg((ci * 2 + 0) * REGS_PER_QUAD + 2); + src[1].y = acc[0][mma_ni].reg((ci * 2 + 0) * REGS_PER_QUAD + 6); + src[1].z = acc[0][mma_ni].reg((ci * 2 + 0) * REGS_PER_QUAD + 3); + src[1].w = acc[0][mma_ni].reg((ci * 2 + 0) * REGS_PER_QUAD + 7); + + src[2].x = acc[0][mma_ni].reg((ci * 2 + 1) * REGS_PER_QUAD + 0); + src[2].y = acc[0][mma_ni].reg((ci * 2 + 1) * REGS_PER_QUAD + 4); + src[2].z = acc[0][mma_ni].reg((ci * 2 + 1) * REGS_PER_QUAD + 1); + src[2].w = acc[0][mma_ni].reg((ci * 2 + 1) * REGS_PER_QUAD + 5); + + src[3].x = acc[0][mma_ni].reg((ci * 2 + 1) * REGS_PER_QUAD + 2); + src[3].y = acc[0][mma_ni].reg((ci * 2 + 1) * REGS_PER_QUAD + 6); + src[3].z = acc[0][mma_ni].reg((ci * 2 + 1) * REGS_PER_QUAD + 3); + src[3].w = acc[0][mma_ni].reg((ci * 2 + 1) * REGS_PER_QUAD + 7); + + using Src_type = typename Traits::Accumulator_type; + using Dst_type = typename Traits::C_type; +// Packs the 32bit values to 8bit. +// Depending on the type, applies extra scaling with parameter scale_bmm2. +#pragma unroll + for (int i = 0; i < 4; ++i) { +#ifdef UNIFIED_EPILOGUE_SCALE + dst[i] = Acc_packer::run(this, src[i]); +#else + dst[i] = Acc_packer::run(this, src[i]); +#endif + } + stsm(stsm_addr[mma_ni][ci], *reinterpret_cast(&dst[0])); + } + } + + // TODO: Interleave STSM and UTMASTG of two N groups + constexpr int MAX_ROWS_PER_WARP = Mma_tile::M_PER_MMA / 4; + if (num_valid_rows_ == MAX_ROWS_PER_WARP) { + fence_view_async_shared(); +#pragma unroll + for (int di = 0; di < N_GROUPS; ++di) { + const int32_t coords[3] = {di * N_PER_GROUP, bidh_, row_tma_}; + fmha::utmastg<3, fmha::cudaTmaDescType::TILED>( + desc_o_, smem_base_ + di * 16 * N_BYTES_PER_GROUP, coords); + } + tmastg_arrive(); + tmastg_wait(); + } else if (num_valid_rows_ > 0) { + // Use LDS.64 + STG.64 to store num_valid_rows_ x N tile + constexpr int BYTES_PER_THREAD = 8; + static_assert((VALID_N % BYTES_PER_THREAD) == 0, "VALID_N must be divided by 8 for STG.64"); + // Number of valid rows + int row_size = num_valid_rows_; + // Number of threads per row. Each thread read/write 8B (8 elements). 
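+      // e.g. (group size assumed for illustration): N_BYTES_PER_GROUP == 128 gives
+      // THREADS_PER_ROW == 16, so a 32-thread warp covers ROWS_PER_WARP == 2 rows per
+      // iteration of the loop below.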
+ constexpr int THREADS_PER_ROW = N_BYTES_PER_GROUP / 8; + // Number of rows read/written by a warp + static_assert(Cta_tile::THREADS_PER_WARP % THREADS_PER_ROW == 0, + "A warp must reads full rows"); + constexpr int ROWS_PER_WARP = Cta_tile::THREADS_PER_WARP / THREADS_PER_ROW; + // GMEM stride in M dimension + int64_t const step_m = (this->params_o_stride_in_bytes_); + // Initial column index + int const ci = lane_ % THREADS_PER_ROW; + int const bank_idx = (ci * BYTES_PER_THREAD) / BYTES_PER_BANK; + int const bank_offset = (ci * BYTES_PER_THREAD) % BYTES_PER_BANK; + +#pragma unroll + for (int di = 0; di < N_GROUPS; ++di) { + // Detect GMEM index out of bound + if ((di * N_BYTES_PER_GROUP + ci * BYTES_PER_THREAD) >= BYTES_PER_ROW) { + break; + } +#pragma unroll + for (int ri = lane_ / THREADS_PER_ROW; ri < row_size; ri += ROWS_PER_WARP) { + // Create the swizzled offset + uint32_t smem_offset = + di * 16 * N_BYTES_PER_GROUP + ri * N_BYTES_PER_GROUP + + (((ri / swizzle_row_divider) % max_swizzle_id) ^ bank_idx) * BYTES_PER_BANK + + bank_offset; + uint2 buffer; + fmha::lds(buffer, smem_base_ + smem_offset); + int64_t gmem_offset = + (int64_t)ri * step_m + di * N_BYTES_PER_GROUP + ci * BYTES_PER_THREAD; + fmha::stg(o_ptr_ + gmem_offset, buffer); + } + } + } + } else { + int64_t const step_m = 8 * (this->params_o_stride_in_bytes_); + +#pragma unroll + for (int ri = 0; ri < ROWS_PER_THREAD; ++ri) { + if (row_ + ri * 8 >= actual_seqlen_) { + break; + } + +#pragma unroll + for (int mma_ni = 0; mma_ni < VALID_MMAS_N - 1; ++mma_ni) { +// Iterate over 16 columns to pack 4 values per thread. +#pragma unroll + for (int ci = 0; ci < COLS_PER_THREAD / 2; ++ci) { + // Assuming EVEN,EVEN,ODD,ODD column pattern due to packing of V. + uint4 src; + src.x = acc[0][mma_ni].reg(((2 * ci + 0) * ROWS_PER_THREAD + ri) * 2 + 0); // 0 + src.y = acc[0][mma_ni].reg(((2 * ci + 1) * ROWS_PER_THREAD + ri) * 2 + 0); // 4 + src.z = acc[0][mma_ni].reg(((2 * ci + 0) * ROWS_PER_THREAD + ri) * 2 + 1); // 1 + src.w = acc[0][mma_ni].reg(((2 * ci + 1) * ROWS_PER_THREAD + ri) * 2 + 1); // 5 + + using Src_type = typename Traits::Accumulator_type; + using Dst_type = typename Traits::C_type; + // Packs the 32bit values to 8bit. + // Depending on the type, applies extra scaling with parameter scale_bmm2. +#ifdef UNIFIED_EPILOGUE_SCALE + uint32_t dst = Acc_packer::run(this, src); +#else + uint32_t dst = Acc_packer::run(this, src); +#endif + + int64_t offset = + (int64_t)ri * step_m + (int64_t)(ci + mma_ni * COLS_PER_THREAD / 2) * STEP_N; + fmha::stg(o_ptr_ + offset, dst); + } // ci + } // mma_ni + + if constexpr (PACK_4_ELTS) { + // The last mma_n may not store full elements back to GMEM. + int mma_ni = VALID_MMAS_N - 1; +// Iterate over 16 columns to pack 4 values per thread. +#pragma unroll + for (int ci = 0; ci < VALID_COLS_PER_THREAD_FOR_LAST_MMA / 2; ++ci) { + // Assuming EVEN,EVEN,ODD,ODD column pattern due to packing of V. + uint4 src; + src.x = acc[0][mma_ni].reg(((2 * ci + 0) * ROWS_PER_THREAD + ri) * 2 + 0); // 0 + src.y = acc[0][mma_ni].reg(((2 * ci + 1) * ROWS_PER_THREAD + ri) * 2 + 0); // 4 + src.z = acc[0][mma_ni].reg(((2 * ci + 0) * ROWS_PER_THREAD + ri) * 2 + 1); // 1 + src.w = acc[0][mma_ni].reg(((2 * ci + 1) * ROWS_PER_THREAD + ri) * 2 + 1); // 5 + + using Src_type = typename Traits::Accumulator_type; + using Dst_type = typename Traits::C_type; + // Packs the 32bit values to 8bit. + // Depending on the type, applies extra scaling with parameter scale_bmm2. 
+#ifdef UNIFIED_EPILOGUE_SCALE + uint32_t dst = Acc_packer::run(this, src); +#else + uint32_t dst = Acc_packer::run(this, src); +#endif + + int64_t offset = + (int64_t)ri * step_m + (int64_t)(ci + mma_ni * COLS_PER_THREAD / 2) * STEP_N; + fmha::stg(o_ptr_ + offset, dst); + } // ci + } else { + // The last mma_n may not store full elements back to GMEM. + int mma_ni = VALID_MMAS_N - 1; +// Iterate over 16 columns to pack 4 values per thread (2 uint2). +#pragma unroll + for (int ci = 0; ci < fmha::Div_up::VALUE; ++ci) { + // Assuming EVEN,EVEN,ODD,ODD column pattern due to packing of V. + uint2 src0, src1; + src0.x = acc[0][mma_ni].reg(((2 * ci + 0) * ROWS_PER_THREAD + ri) * 2 + 0); // 0 + src0.y = acc[0][mma_ni].reg(((2 * ci + 1) * ROWS_PER_THREAD + ri) * 2 + 0); // 4 + src1.x = acc[0][mma_ni].reg(((2 * ci + 0) * ROWS_PER_THREAD + ri) * 2 + 1); // 1 + src1.y = acc[0][mma_ni].reg(((2 * ci + 1) * ROWS_PER_THREAD + ri) * 2 + 1); // 5 + + using Src_type = typename Traits::Accumulator_type; + using Dst_type = typename Traits::C_type; +#ifdef UNIFIED_EPILOGUE_SCALE + uint16_t dst0 = Acc_packer::run(this, src0); + uint16_t dst1 = Acc_packer::run(this, src1); +#else + uint16_t dst0 = Acc_packer::run(this, src0); + uint16_t dst1 = Acc_packer::run(this, src1); +#endif + + // 4 elements per thread, so 16 elements per loop. + int col_idx = (ci + mma_ni * COLS_PER_THREAD / 2) * 16; + + int64_t offset = (int64_t)ri * step_m + (int64_t)(col_idx)*BYTES_PER_ELEMENT; + + if (col_idx + col_ < VALID_N) { + fmha::stg(o_ptr_ + offset, dst0); + } + + if (col_idx + col_ + 2 < VALID_N) { + fmha::stg(o_ptr_ + offset + 2 * BYTES_PER_ELEMENT, dst1); + } + } // ci + } + } // ri + } + } + + // Move to the next location. + inline __device__ void move() { + row_ += ROWS; + o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_; + } + + // The stride between rows for the QKV matrice. + int64_t params_o_stride_in_bytes_; + // The pointer. + char* o_ptr_; + // Is the thread active for the last STG? + int is_active_for_last_stg_; + + // The row, col loaded by this thread. + int row_, col_; + // The length of the sequence loaded by that CTA. + int actual_seqlen_; + + // Scaling factor; this usually means QKV descale factor in actuality + uint32_t params_scale_bmm2_; + + // Smem buffer for TMASTG + uint32_t smem_base_; + cudaTmaDesc const* desc_o_; + + int lane_; + int row_tma_; + int num_valid_rows_; + int bidh_; + + bool const params_enable_i2f_trick_ = false; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_hopper_32bit_8bit {}; + +template +struct Gmem_tile_o_hopper_32bit_8bit + : public Gmem_tile_o_gmma_32bit_8bit { + // The Base class. + using Base = Gmem_tile_o_gmma_32bit_8bit; + + // Ctor. + template + inline __device__ Gmem_tile_o_hopper_32bit_8bit(Params const& params, + Block_info const& block_info, Shared& shared, + int tidx, int cta_row_offset = 0) + : Base(params, block_info, shared, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_hopper_32bit_8bit + : public Gmem_tile_o_8bit { + // The Base class. + using Base = Gmem_tile_o_8bit; + + // Ctor. 
+ template + inline __device__ Gmem_tile_o_hopper_32bit_8bit(Params const& params, + Block_info const& block_info, Shared& shared, + int tidx, int cta_row_offset = 0) + : Base(params, block_info, shared, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_hopper< + fmha::Hopper_qgmma_fp8_fp32_traits, Cta_tile, + CTAS_PER_HEAD> + : public Gmem_tile_o_hopper_32bit_8bit< + fmha::Hopper_qgmma_fp8_fp32_traits, + Cta_tile, Cta_tile::WARPS_K> { + // The traits class. + using Traits = fmha::Hopper_qgmma_fp8_fp32_traits; + + using Base = Gmem_tile_o_hopper_32bit_8bit< + fmha::Hopper_qgmma_fp8_fp32_traits, Cta_tile, + Cta_tile::WARPS_K>; + + // Ctor. + template + inline __device__ Gmem_tile_o_hopper(Params const& params, Block_info const& block_info, + Shared& shared, int tidx, int cta_row_offset = 0) + : Base(params, block_info, shared, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o< + fmha::Hopper_igmma_int8_int32_traits, Cta_tile, + CTAS_PER_HEAD> + : public Gmem_tile_o_hopper_32bit_8bit< + fmha::Hopper_igmma_int8_int32_traits, + Cta_tile, Cta_tile::WARPS_K> { + // The traits class. + using Traits = fmha::Hopper_igmma_int8_int32_traits; + + using Base = Gmem_tile_o_hopper_32bit_8bit< + fmha::Hopper_igmma_int8_int32_traits, Cta_tile, + Cta_tile::WARPS_K>; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0) + : Base(params, block_info, std::nullptr_t{} /* dummy obj */, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_qgmma_fp32_16bits { + // The associated MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of elements per STG. + enum { ELEMENTS_PER_STG = 2 }; + + // The size in bytes of each element. + enum { BYTES_PER_ELEMENT = 2 }; + + // The size of each STG. + enum { BYTES_PER_STG = ELEMENTS_PER_STG * BYTES_PER_ELEMENT }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::VALID_N * BYTES_PER_ELEMENT }; + + // The number of rows accessed by each thread. + enum { ROWS_PER_THREAD = Mma_tile::M_PER_MMA / 8 / Cta_tile::WARPS_PER_CTA }; + + enum { ROWS = Cta_tile::M }; + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLS_PER_THREAD = Mma_tile::N_PER_MMA / 4 / 2 }; + + // The number of valid columns (stored to GMEM) by each thread. + enum { + VALID_COLS_PER_THREAD_FOR_LAST_MMA = (Cta_tile::VALID_N % Mma_tile::N_PER_MMA) == 0 + ? COLS_PER_THREAD + : (Cta_tile::VALID_N % Mma_tile::N_PER_MMA) / 8 + }; + + enum { VALID_MMAS_N = fmha::Div_up::VALUE }; + + static_assert(Cta_tile::VALID_N % 8 == 0, "The valid head dimension needs to be multiple of 8."); + + // The number of accumulator held by each thread, per HGMMA instruction. + enum { ELTS_PER_THREAD = ROWS_PER_THREAD * COLS_PER_THREAD }; + + // Currently, we assume for o matrix, GMMA M/N shape matches CTA M/N shape. + static_assert(Mma_tile::M_PER_MMA == Cta_tile::M && + Mma_tile::N_PER_MMA * Mma_tile::MMAS_N == Cta_tile::N, + "Currently, we assume for o matrix, GMMA M shape matches CTA M shape. "); + + // Step N for one quad + enum { STEP_N = 8 * BYTES_PER_ELEMENT }; + + // Ctor. 
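+ // Illustration of the thread mapping computed in the ctor below (assuming the usual GMMA
+ // M of 64, i.e. Mma_tile::M_PER_MMA / 4 == 16 rows per warp of the warpgroup): tidx = 70
+ // decomposes into warp 2, lane 6, so row = 2 * 16 + 6 / 4 = 33 and
+ // col = (6 % 4) * ELEMENTS_PER_STG * 2 = 8 elements from the start of the row.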
+ template + inline __device__ Gmem_tile_o_qgmma_fp32_16bits(Params const& params, + Block_info const& block_info, Shared&&, int tidx, + int cta_row_offset = 0) + : params_o_stride_in_bytes_(params.o_stride_in_bytes), + params_scale_bmm2_( +#ifdef GENERATE_CUBIN + // Specialized for trt-llm generated cubins only. + params.scale_bmm2_d ? *params.scale_bmm2_d : params.scale_bmm2 +#else + params.scale_bmm2 +#endif + ), + actual_seqlen_(block_info.actual_seqlen), + o_ptr_(reinterpret_cast(params.o_ptr)) { + static_assert(!std::is_same::value, "Check constructor argument type!"); + // Decompose the position of the thread into warp/lane. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + int warp_idx_within_warpgroup = warp % 4; + + // Compute the position in the sequence (within the CTA for the moment). + int row = warp_idx_within_warpgroup * (Mma_tile::M_PER_MMA / 4) + lane / 4; + // Store the row to update the predicates in load. + row_ = cta_row_offset + row; + // Compute the position of the thread in the row. + // echo loop handles 2 cores, so x2 (this is the difference to Gmem_tile_o_hopper_16bits) + int col = lane % 4 * ELEMENTS_PER_STG * 2; + + // The offset of the 1st row written by the thread. We store the P matrix interleaved. + int64_t row_offset = + (int64_t)row_ * params_o_stride_in_bytes_ + block_info.bidx * BYTES_PER_ROW; + // Finalize the pointer. + o_ptr_ += row_offset + col * BYTES_PER_ELEMENT; + } + + // Store data to memory. + template + inline __device__ void store(Accumulators const (&acc)[M][N]) { + int64_t const step_m = 8 * params_o_stride_in_bytes_; +#ifdef UNIFIED_EPILOGUE_SCALE + constexpr bool Scale = false; +#else + constexpr bool Scale = true; +#endif +#define STORE_COLUMNS() \ + { \ + /* we assume M = 1. some shortcuts. */ \ + static_assert(M == 1); \ + uint4 _src = { \ + .x = acc[0][mma_ni].reg(((ci + 0) * ROWS_PER_THREAD + ri) * 2), \ + .y = acc[0][mma_ni].reg(((ci + 1) * ROWS_PER_THREAD + ri) * 2), \ + .z = acc[0][mma_ni].reg(((ci + 0) * ROWS_PER_THREAD + ri) * 2 + 1), \ + .w = acc[0][mma_ni].reg(((ci + 1) * ROWS_PER_THREAD + ri) * 2 + 1), \ + }; \ + uint2 _dst = Acc_packer::run(this, _src); \ + int64_t _offset = (int64_t)ri * step_m + (int64_t)(ci + mma_ni * COLS_PER_THREAD) * STEP_N; \ + fmha::stg(o_ptr_ + _offset, _dst); \ + } + +#pragma unroll + for (int ri = 0; ri < ROWS_PER_THREAD; ri++) { + if (row_ + ri * 8 >= actual_seqlen_) { + break; + } +#pragma unroll + for (int mma_ni = 0; mma_ni < VALID_MMAS_N - 1; ++mma_ni) { +#pragma unroll + for (int ci = 0; ci < COLS_PER_THREAD; ci += 2) { + STORE_COLUMNS() + } + } + // The last mma_n may not store full elements back to GMEM. + int mma_ni = VALID_MMAS_N - 1; +#pragma unroll + for (int ci = 0; ci < VALID_COLS_PER_THREAD_FOR_LAST_MMA; ci += 2) { + STORE_COLUMNS() + } + } + } + + // Move to the next location. + inline __device__ void move() { + row_ += ROWS; + o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_; + } + + // The stride between rows for the QKV matrice. + int64_t params_o_stride_in_bytes_; + // Scaling factor; this usually means QKV descale factor in actuality + uint32_t params_scale_bmm2_; + // The pointer. + char* o_ptr_; + // The row loaded by this thread. + int row_; + // The length of the sequence loaded by that CTA. 
+ int actual_seqlen_; +}; + +} // namespace v2 + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h b/csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h new file mode 100644 index 0000000000..5ee0ac50d1 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h @@ -0,0 +1,146 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include + +namespace fmha { +namespace v2 { + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The number of bits per element. + int BITS_PER_ELEMENT, + // The number of rows of Q, K or V loaded by this tile. + int ROWS_, + // The number of columns. + int COLS, + // Do we use LDGSTS? + bool USE_LDGSTS_, + // Are attention heads interleaved? + bool HEADS_INTERLEAVED, + // The number of matrices + int NUM_MATS = 3> +struct Gmem_tile_tma_qkv { + // The size of each LDG. + enum { BYTES_PER_LDG = 16 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 }; + + // The number of threads to load a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; + + // The number of "rows" loaded per LDG. + enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of rows. + enum { ROWS = ROWS_ }; + + // The number of LDGs needed to load a chunk of the Q matrix. + enum { LDGS = fmha::Div_up::VALUE }; + + // The number of predicate registers. + enum { PRED_REGS = fmha::Compute_number_of_pred_regs::VALUE }; + + // Is it Hopper? + enum { + IS_HOPPER = std::is_same::value == true + }; + + // Make sure we use a single register to store predicates. Do not throw for Hopper for now. + static_assert(!USE_LDGSTS_ || PRED_REGS == 1 || IS_HOPPER, ""); + + // We do not use LDGSTS (for the moment). + enum { USE_LDGSTS = USE_LDGSTS_ }; + + // TMA DIMS, hard coded for now + enum { TMA_DIMS = 3 }; + + // TMA DESC type, hard coded for now + static constexpr fmha::cudaTmaDescType TMA_DESC_TYPE = fmha::cudaTmaDescType::TILED; + + // Ctor. 
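+ // Worked example of the TMA box origin computed in the ctor below for the packed
+ // QKV layout [b, s, [q_hd, k_h'd, v_h'd]] (numbers are illustrative only):
+ // with h = 32, h_kv = 8, d = 128 and head bidh = 13, hi_kv = 13 / (32 / 8) = 3, so
+ //   Q: coord[0] = 13 * 128                     = 1664
+ //   K: coord[0] = 32 * 128 + 3 * 128           = 4480
+ //   V: coord[0] = 32 * 128 + 8 * 128 + 3 * 128 = 5504
+ // whereas the MHA branch uses qkv_offset * d + bidh * 3 * d instead.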
+ template + inline __device__ Gmem_tile_tma_qkv(Params const& params, cudaTmaDesc const* p_desc, + int qkv_offset, Block_info const& block_info, int tidx, + int cta_row_offset = 0) + // in PACKED_QKV, q_stride = k_stride = v_stride + : params_qkv_stride_in_bytes_(params.q_stride_in_bytes), + actual_seqlen_(block_info.actual_seqlen), + qkv_ptr_(reinterpret_cast(params.qkv_ptr)), + p_desc_(p_desc) { + // Both MQA and GQA will use non HEADS_INTERLEAVED layout + if (params.h_kv < params.h) { + // QKV layout [b, s, [q_hd, k_h'd, v_h'd]] + int const hi = block_info.bidh; + int const hi_kv = block_info.bidh / (params.h / params.h_kv); + if (qkv_offset == 0) { // Q tensor + coord[0] = hi * params.d; + } else if (qkv_offset == 1) { // K tensor + coord[0] = params.h * params.d + hi_kv * params.d; + } else if (qkv_offset == 2) { // V tensor + coord[0] = params.h * params.d + params.h_kv * params.d + hi_kv * params.d; + } + } else { + coord[0] = qkv_offset * params.d + block_info.bidh * params.d * 3; + } + // coord[1] = block_info.bidb * params.s; // should be params.s * batch_idx + // coord[1] do not need to be adjusted per batch. + // since the gmem_ptr in tma desc is set per batch and already adjusted. + coord[1] = block_info.sum_s; + coord[2] = 0; + } + + // Store data to shared memory. + template + inline __device__ void commit(Smem_tile& smem_tile) {} + + // Load data from memory. + template + inline __device__ void load(Smem_tile& smem_tile) { + smem_tile.template store(p_desc_, coord); + } + + // Store data to memory. + inline __device__ void store(uint4 const (&data)[LDGS]) {} + + // Move the pointer to the next location. + // only needed by matrix Q. + inline __device__ void move() { + // coord[1] is incremented by STEP size. + coord[1] += ROWS; + } + + // The stride between rows for the QKV matrice. + int64_t params_qkv_stride_in_bytes_; + // The pointer. + char* qkv_ptr_; + // The register to store predicates. + uint32_t preds_[PRED_REGS]; + // The fetch registers. + uint4 fetch_[LDGS]; + // Keep track of the row the thread is processing as we move the tile. + int row_; + // The sequence length. + int actual_seqlen_; + // tma descriptor + cudaTmaDesc const* p_desc_; + // coord use by TMA. For now hard code to 3D. + int32_t coord[3]; +}; + +} // namespace v2 +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/gmma_descriptor.h b/csrc/fmha_v2/fmha/hopper/gmma_descriptor.h new file mode 100644 index 0000000000..8b4129e343 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/gmma_descriptor.h @@ -0,0 +1,547 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +namespace fmha { +//////////////////////////////////////////////////////////////////////////////////////////////////// +// whether transpose is applied on the smem before GMMA math execution +// if TN, notrans is applied to both A and B. as GMMA expects the data +// to be in TN format. +// if NT, trans is applied to both A and B. 
+//////////////////////////////////////////////////////////////////////////////////////////////////// +enum class Gmma_descriptor_transpose { TRANS, NOTRANS }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Gmma descriptor mode +// 2 bits to specify the descriptor mode. +//////////////////////////////////////////////////////////////////////////////////////////////////// +enum class Gmma_descriptor_mode { SWIZZLE_NONE = 0, SWIZZLE_128B, SWIZZLE_64B, SWIZZLE_32B }; +constexpr uint32_t GMMA_DESCRIPTOR_MODE_BITS = 2; +constexpr uint32_t GMMA_DESCRIPTOR_MODE_SHIFT = 62; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// number of descriptor per GMMA group to be actually allocated per kblock +//////////////////////////////////////////////////////////////////////////////////////////////////// +enum class Gmma_descriptor_size { + ONE, + TWO, // not yet implemented. might be needed for 64xNxK tile size. + // as many as needed (kblock / gmma_k). we may not prefer to use this as we may run out of UR + // budget + ALL +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// a single desc that has the info and bits +//////////////////////////////////////////////////////////////////////////////////////////////////// +template +class Single_descriptor { + public: + // trans mode + static constexpr Gmma_descriptor_transpose TRANS_MODE = Gmma_trans; + + // set the single desc + inline __device__ void set(uint64_t const& desc_) { desc = desc_; } + + // get the single desc + inline __device__ uint64_t get() const { return desc; } + + private: + // the descriptor, each of 64 bit + uint64_t desc; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// for a +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Gmma_descriptor_a { + public: + // The type of the Single Descriptor + using Single_desc = Single_descriptor; + + // Transpose Mode + static constexpr Gmma_descriptor_transpose TRANS_MODE = Gmma_trans; + + // The number of descriptors per 64xNblockxKblock. + static constexpr Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = Gmma_vector_size; + + // Currently the number of descriptors per 64xNblockxKblock is always One + // Historically we have supported more descriptors. But that has proven to + // be less performant as it consumes too many uniform registers. + // During the process of refactoring we have decided to only support allocating + // one desc per 64xNblockxKblock. If needed in the future, we can support + // more desc. + static_assert(Gmma_vector_size == Gmma_descriptor_size::ONE, + "Currently, only Mblock/64 desc is allocated per kgroup\n"); + + // Interleaved Mode is currently not supported. + // static_assert to avoid accidentally instantiate it. + static_assert(Gmma_mode != Gmma_descriptor_mode::SWIZZLE_NONE, + "Currently, SWIZZLE_NONE mode is not implemented. \n"); + + // byte per leading dim (row if TN, column is NT) must be 128 + enum { BYTES_PER_LEADING_DIM = 128 }; + + // bytes per element + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // the number of descriptors per kblock is related to GMMA shape and kblock size + enum { + NUM_DESCRIPTORS = (Gmma_vector_size == Gmma_descriptor_size::ALL) ? 
Cta_tile::K / GMMA_K : 1 + }; + + // the number of descriptors per 128 byte in k dimension (leading dim) + // NUM_DESCRIPTORS_PER_128B_IN_K is really only needed if leading dim is K + enum { + NUM_DESCRIPTORS_PER_128B_IN_K = (Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B && + Gmma_trans == Gmma_descriptor_transpose::NOTRANS) + ? BYTES_PER_LEADING_DIM / ((GMMA_K * BITS_PER_ELEMENT) / 8) + : NUM_DESCRIPTORS + }; + + static constexpr uint32_t BYTES_PER_GMMA_K = GMMA_K * BITS_PER_ELEMENT / 8; // 32B + + // the distance between neighboring descriptors + static constexpr uint32_t BYTES_PER_DESC = + Gmma_vector_size == Gmma_descriptor_size::ALL ? 0 + : Gmma_trans == Gmma_descriptor_transpose::TRANS + ? Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B ? GMMA_K * BYTES_PER_LEADING_DIM + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_64B ? (GMMA_K / 2) * BYTES_PER_LEADING_DIM + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_32B ? (GMMA_K / 4) * BYTES_PER_LEADING_DIM + : 0 + : Gmma_trans == Gmma_descriptor_transpose::NOTRANS + ? Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B || + Gmma_mode == Gmma_descriptor_mode::SWIZZLE_64B + ? BYTES_PER_GMMA_K // 32B + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_32B ? Cta_tile::M * BYTES_PER_GMMA_K + : 0 + : 0; + + // the distance between neighboring desc without 4LSB + static constexpr uint32_t BYTES_PER_DESC_NO_4LSB = BYTES_PER_DESC >> 4; + + // the distance to travel back from the last desc to the first desc within a group + enum { BYTES_DESC_INC_BOUNDARY_NO_4LSB = BYTES_PER_DESC_NO_4LSB * (Cta_tile::K / GMMA_K - 1) }; + + // set GMMA descriptor mode bits. + static constexpr uint64_t DESCRIPTOR_MODE_IN_BIT_LOCATION = + (static_cast(Gmma_mode) & ((1u << GMMA_DESCRIPTOR_MODE_BITS) - 1)) + << GMMA_DESCRIPTOR_MODE_SHIFT; + + // stride byte offset, bit 32-45, 4LSB not included + // each row is always of 128 byte. 8 rows always. + // divide by 16 since the 4 LSB is not included + static constexpr uint64_t STRIDE_BYTE_OFFSET = + BYTES_PER_LEADING_DIM * + ((Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B) ? 8 + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_64B ? 4 + : 2) / + 16; + // shift 32 bit + static constexpr uint64_t STRIDE_BYTE_OFFSET_IN_BIT_LOCATION = STRIDE_BYTE_OFFSET << 32; + + // leading byte offset, bit 16-29, 4LSB not included + // each row is still 128 byte. + // divide by 16 since the 4 LSB is not included + // for A matrix of TN, and the way we reshape the matrix, LEADING_BYTE_OFFSET is never non-zero + // in the future with different GMMA shape, this might be needed + static constexpr bool LEADING_BYTE_OFFSET_NEEDED = false; + + // the leading byte offset if needed 4LSB not included + static constexpr uint64_t LEADING_BYTE_OFFSET = + Gmma_mode == Gmma_descriptor_mode::SWIZZLE_32B + ? BYTES_PER_LEADING_DIM / 16 + : BYTES_PER_LEADING_DIM * + ((Gmma_trans == Gmma_descriptor_transpose::TRANS) ? Cta_tile::K : Cta_tile::M) / 16; + // shift 16 bit + static constexpr uint64_t LEADING_BYTE_OFFSET_IN_BIT_LOCATION = + LEADING_BYTE_OFFSET_NEEDED ? 
LEADING_BYTE_OFFSET << 16 : 0; + + // ctor + inline __device__ Gmma_descriptor_a() { +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] = 0; + } + +// set bit 62-63 to 1 for SWIZZLE_128B format +// set bit 62-63 to 2 for SWIZZLE_64B format +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] |= DESCRIPTOR_MODE_IN_BIT_LOCATION; + } + +// stride byte offset, bit 32-45, 4LSB not included +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] |= STRIDE_BYTE_OFFSET_IN_BIT_LOCATION; + } + + // leading byte offset, bit 16-29, 4LSB not included + if (LEADING_BYTE_OFFSET_NEEDED) { +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] |= LEADING_BYTE_OFFSET_IN_BIT_LOCATION; + } + } + } + + // update the descriptor based on smem address. Should be called once from prologue. + inline __device__ void set_smem_pointer(uint32_t smem_nvvm_pointer) { + // uint32_t smem_nvvm_pointer = get_smem_pointer(smem); + uint64_t smem_address_bit = static_cast(smem_nvvm_pointer); + + // set base offset, bit 49-61 + uint64_t offset = (smem_address_bit / BYTES_PER_LEADING_DIM) % + ((Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B) ? 8 + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_64B ? 4 + : 2); + uint64_t offset_in_bit_location = offset << 49; +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] |= offset_in_bit_location; + } + +// start_address, bit 0-13, 4LSB not included (so grab bit 4-17) +// the only bits that is different for each desc of the same obj +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + // for fp16, desc_idx_in_128B should range from 0 to 3 + int desc_idx_in_128B = desc_idx % NUM_DESCRIPTORS_PER_128B_IN_K; + int desc_idx_over_128B = desc_idx / NUM_DESCRIPTORS_PER_128B_IN_K; + + uint64_t smem_address_bit_in_bit_location = + (smem_address_bit + ((GMMA_K * BITS_PER_ELEMENT) / 8) * desc_idx_in_128B + + Cta_tile::M * BYTES_PER_LEADING_DIM * desc_idx_over_128B) + << 46; + + smem_address_bit_in_bit_location = smem_address_bit_in_bit_location >> 50; + desc[desc_idx] |= smem_address_bit_in_bit_location; + } + } + + // get a single desc from the desc group. + inline __device__ uint64_t get_descriptor(int desc_idx) const { + // printf("desc[0] = 0x%lx\n", desc[0]); + return desc[(Gmma_vector_size == Gmma_descriptor_size::ALL) ? desc_idx : 0]; + } + + // get the max descriptor for desc[0] + inline __device__ uint64_t get_max_descriptor_0() const { return max_desc_0; } + + // set a single desc from the desc group. + inline __device__ void set_descriptor(int desc_idx, uint64_t single_desc) { + desc[(Gmma_vector_size == Gmma_descriptor_size::ALL) ? desc_idx : 0] = single_desc; + } + + // set the max descriptor for desc[0]. Should be called once from prologue. + // Should be called with set_smem_pointer() + // This value is needed to "loop back" to the first LDGSTS buffer when appropriate. + inline __device__ void set_max_descriptor_0(int mem_offset_no_4LSB) { + max_desc_0 = desc[0] + mem_offset_no_4LSB; + } + + // for desc group where all desc all allocated, + // increment_single_descriptor() will do nothing. + inline __device__ void increment_single_descriptor(bool last_of_kblock) { + // update smem start address, which is in lower 32bits. 
+ int2& tmp = reinterpret_cast(desc[0]); + if (last_of_kblock == true) { + tmp.x -= BYTES_DESC_INC_BOUNDARY_NO_4LSB; + } else { + tmp.x += BYTES_PER_DESC_NO_4LSB; + } + } + + template + inline __device__ void increment_single_descriptor() { + int2& tmp = reinterpret_cast(desc[0]); + tmp.x += (BYTE_OFFSET >> 4); + } + + private: + // the descriptors, each of 64 bit + uint64_t desc[NUM_DESCRIPTORS]; + // the max desc for desc_idx = 0 + uint64_t max_desc_0; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// for b +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Gmma_descriptor_b { + public: + // The type of the Single Descriptor + using Single_desc = Single_descriptor; + + // Transpose mode. + static constexpr Gmma_descriptor_transpose TRANS_MODE = Gmma_trans; + + // The number of descriptors per 64xNblockxKblock. + static constexpr Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = Gmma_vector_size; + + // Currently the number of descriptors per 64xNblockxKblock is always One + // Historically we have supported more descriptors. But that has proven to + // be less performant as it consumes too many uniform registers. + // During the process of refactoring we have decided to only support allocating + // one desc per 64xNblockxKblock. If needed in the future, we can support + // more desc. + static_assert(Gmma_vector_size == Gmma_descriptor_size::ONE, + "Currently, only Mblock/64 desc is allocated per kgroup\n"); + + // Interleaved Mode is currently not supported. + // static_assert to avoid accidentally instantiate it. + static_assert(Gmma_mode != Gmma_descriptor_mode::SWIZZLE_NONE, + "Currently, SWIZZLE_NONE mode is not implemented. \n"); + + // byte per leading dim (column if TN, row if NT), must be 128 + enum { BYTES_PER_LEADING_DIM = 128 }; + + // bytes per element + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // the number of descriptors per kblock is related to GMMA shape and kblock size + enum { + NUM_DESCRIPTORS = (Gmma_vector_size == Gmma_descriptor_size::ALL) ? Cta_tile::K / GMMA_K : 1 + }; + + // the number of descriptors per 128 byte in k dimension (leading dim) + // NUM_DESCRIPTORS_PER_128B_IN_K is really only needed if leading dim is K + enum { + NUM_DESCRIPTORS_PER_128B_IN_K = (Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B && + Gmma_trans == Gmma_descriptor_transpose::NOTRANS) + ? BYTES_PER_LEADING_DIM / ((GMMA_K * BITS_PER_ELEMENT) / 8) + : NUM_DESCRIPTORS + }; + + static constexpr uint32_t BYTES_PER_GMMA_K = GMMA_K * BITS_PER_ELEMENT / 8; // 32B + + // the distance between neighboring descriptors + static constexpr uint32_t BYTES_PER_DESC = + Gmma_vector_size == Gmma_descriptor_size::ALL ? 0 + : Gmma_trans == Gmma_descriptor_transpose::TRANS + ? Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B ? GMMA_K * BYTES_PER_LEADING_DIM + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_64B ? (GMMA_K / 2) * BYTES_PER_LEADING_DIM + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_32B ? (GMMA_K / 4) * BYTES_PER_LEADING_DIM + : 0 + : Gmma_trans == Gmma_descriptor_transpose::NOTRANS + ? Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B || + Gmma_mode == Gmma_descriptor_mode::SWIZZLE_64B + ? BYTES_PER_GMMA_K // 32B + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_32B ? 
GMMA_N * BYTES_PER_GMMA_K + : 0 + : 0; + + // the distance between neighboring desc without 4LSB + static constexpr uint32_t BYTES_PER_DESC_NO_4LSB = BYTES_PER_DESC >> 4; + + // the distance to travel back from the last desc to the first desc within a group + enum { BYTES_DESC_INC_BOUNDARY_NO_4LSB = BYTES_PER_DESC_NO_4LSB * (Cta_tile::K / GMMA_K - 1) }; + + // Byte count on tile-K dimension + enum { + RESET_SMEM = ((Gmma_trans == Gmma_descriptor_transpose::NOTRANS) && + (((Cta_tile::K * BITS_PER_ELEMENT) / (8 * BYTES_PER_LEADING_DIM)) > 1)) + ? true + : false + }; + + // Reset bytes per BYTES_PER_LEADING_DIM (128) x tile-N + enum { RESET_BYTES_NO_4LSB = (BYTES_PER_LEADING_DIM * Cta_tile::N) / 16 }; + + // set GMMA descriptor mode bits. + static constexpr uint64_t DESCRIPTOR_MODE_IN_BIT_LOCATION = + (static_cast(Gmma_mode) & ((1u << GMMA_DESCRIPTOR_MODE_BITS) - 1)) + << GMMA_DESCRIPTOR_MODE_SHIFT; + + // stride byte offset, bit 32-45, 4LSB not included + // each column is always of 128 byte. 8 columns always. + // divide by 16 since the 4 LSB is not included + static constexpr uint64_t STRIDE_BYTE_OFFSET = + BYTES_PER_LEADING_DIM * + ((Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B) ? 8 + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_64B ? 4 + : 2) / + 16; + // shift 32 bit + static constexpr uint64_t STRIDE_BYTE_OFFSET_IN_BIT_LOCATION = STRIDE_BYTE_OFFSET << 32; + + // leading byte offset, bit 16-29, 4LSB not included + // each column is still 128 byte. + // divide by 16 since the 4 LSB is not included + // for B matrix of TN, and the way we reshape the matrix, LEADING_BYTE_OFFSET is never non-zero + // in the future with different GMMA shape, this might be needed + static constexpr bool LEADING_BYTE_OFFSET_NEEDED = + (((GMMA_N * BITS_PER_ELEMENT) / 8 > BYTES_PER_LEADING_DIM && + Gmma_trans == Gmma_descriptor_transpose::TRANS) || + GMMA_K == 64) + ? true + : false; + + // the leading byte offset if needed 4LSB not included + static constexpr uint64_t LEADING_BYTE_OFFSET = + GMMA_K == 64 + ? Cta_tile::N * 32 / 16 + : (BYTES_PER_LEADING_DIM * + ((Gmma_trans == Gmma_descriptor_transpose::TRANS) ? Cta_tile::K : Cta_tile::N) / 16); + // shift 16 bit + static constexpr uint64_t LEADING_BYTE_OFFSET_IN_BIT_LOCATION = + LEADING_BYTE_OFFSET_NEEDED ? LEADING_BYTE_OFFSET << 16 : 0; + + // ctor + inline __device__ Gmma_descriptor_b() { +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] = 0; + } + +// set bit 62-63 to 1 for SWIZZLE_128B format +// set bit 62-63 to 2 for SWIZZLE_64B format +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] |= DESCRIPTOR_MODE_IN_BIT_LOCATION; + } + +// stride byte offset, bit 32-45, 4LSB not included +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] |= STRIDE_BYTE_OFFSET_IN_BIT_LOCATION; + } + + // leading byte offset, bit 16-29, 4LSB not included + if (LEADING_BYTE_OFFSET_NEEDED) { +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] |= LEADING_BYTE_OFFSET_IN_BIT_LOCATION; + } + } + } + + // update the descriptor based on smem address. Should be called once from prologue. 
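+ // Summary of the 64-bit descriptor layout assembled by the ctor above and by
+ // set_smem_pointer() below (offsets drop their 4 LSB, i.e. are in 16-byte units):
+ //   bits  0-13 : start address (taken from bits 4-17 of the smem pointer)
+ //   bits 16-29 : leading byte offset (only when LEADING_BYTE_OFFSET_NEEDED)
+ //   bits 32-45 : stride byte offset
+ //   bits 49-61 : base offset
+ //   bits 62-63 : descriptor/swizzle mode (SWIZZLE_128B / 64B / 32B)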
+ inline __device__ void set_smem_pointer(uint32_t smem_nvvm_pointer) { + // uint64_t smem_address_bit = reinterpret_cast(smem); + // uint32_t smem_nvvm_pointer = get_smem_pointer(smem); + uint64_t smem_address_bit = static_cast(smem_nvvm_pointer); + + // set base offset, bit 49-61 + uint64_t offset = (smem_address_bit / BYTES_PER_LEADING_DIM) % + ((Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B) ? 8 + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_64B ? 4 + : 2); + uint64_t offset_in_bit_location = offset << 49; +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] |= offset_in_bit_location; + } + +// start_address, bit 0-13, 4LSB not included(so grab bit 4-17) +// the only bits that is different for each desc of the same obj +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + // for fp16, desc_idx_in_128B should range from 0 to 3 + int desc_idx_in_128B = desc_idx % NUM_DESCRIPTORS_PER_128B_IN_K; + int desc_idx_over_128B = desc_idx / NUM_DESCRIPTORS_PER_128B_IN_K; + + uint64_t smem_address_bit_in_bit_location = + (smem_address_bit + ((GMMA_K * BITS_PER_ELEMENT) / 8) * desc_idx_in_128B + + Cta_tile::N * BYTES_PER_LEADING_DIM * desc_idx_over_128B) + << 46; + smem_address_bit_in_bit_location = smem_address_bit_in_bit_location >> 50; + desc[desc_idx] |= smem_address_bit_in_bit_location; + } + } + + // get a single desc from the desc group. + inline __device__ uint64_t get_descriptor(int desc_idx) const { + // if(threadIdx.x == 128) + // printf("desc[0] = 0x%lx\n", desc[0]); + //__syncwarp(); + return desc[(Gmma_vector_size == Gmma_descriptor_size::ALL) ? desc_idx : 0]; + } + + // get the max descriptor for desc[0] + inline __device__ uint64_t get_max_descriptor_0() const { return max_desc_0; } + + // set a single desc from the desc group. + inline __device__ void set_descriptor(int desc_idx, uint64_t single_desc) { + desc[(Gmma_vector_size == Gmma_descriptor_size::ALL) ? desc_idx : 0] = single_desc; + } + + // set the max descriptor for desc[0]. Should be called once from prologue. + // Should be called with set_smem_pointer() + // This value is needed to "loop back" to the first LDGSTS buffer when appropriate. + inline __device__ void set_max_descriptor_0(int mem_offset_no_4LSB) { + max_desc_0 = desc[0] + mem_offset_no_4LSB; + } + + // for desc group where all desc all allocated, + // increment_single_descriptor() will do nothing. + inline __device__ void increment_single_descriptor(bool last_of_kblock) { + // update smem start address, which is in lower 32bits. + int2& tmp = reinterpret_cast(desc[0]); + if (last_of_kblock == true) { + tmp.x -= BYTES_DESC_INC_BOUNDARY_NO_4LSB; + } else { + tmp.x += BYTES_PER_DESC_NO_4LSB; + } + } + + template + inline __device__ void increment_single_descriptor() { + int2& tmp = reinterpret_cast(desc[0]); + tmp.x += (BYTE_OFFSET >> 4); + } + + // for desc group where all desc all allocated, + // increment_single_descriptor() will do nothing. + inline __device__ void increment_single_descriptor(bool last_of_kblock, bool switch_kblock) { + // update smem start address, which is in lower 32bits. 
+ int2& tmp = reinterpret_cast(desc[0]); + if (RESET_SMEM) { + if (switch_kblock) { + tmp.x -= BYTES_PER_DESC_NO_4LSB; + tmp.x += RESET_BYTES_NO_4LSB; + } else { + if (last_of_kblock == true) { + tmp.x -= BYTES_PER_DESC_NO_4LSB; + tmp.x -= RESET_BYTES_NO_4LSB; + } else { + tmp.x += BYTES_PER_DESC_NO_4LSB; + } + } + } else { + if (last_of_kblock == true) { + tmp.x -= BYTES_DESC_INC_BOUNDARY_NO_4LSB; + } else { + tmp.x += BYTES_PER_DESC_NO_4LSB; + } + } + } + + private: + // the descriptors, each of 64 bit + uint64_t desc[NUM_DESCRIPTORS]; + // the max desc for desc_idx = 0 + uint64_t max_desc_0; +}; + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/kernel_traits.h b/csrc/fmha_v2/fmha/hopper/kernel_traits.h new file mode 100644 index 0000000000..edeff1e281 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/kernel_traits.h @@ -0,0 +1,365 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // Instruction traits. + typename Traits_p_, + // Instruction traits. + typename Traits_o_, + // The ldgsts global memory tile for Q, K and V. + template class Gmem_tile_qkv_, + // The tma global memory tile for Q, K and V. + template class Gmem_tile_tma_qkv_, + // The global memory tile for the output. + template class Gmem_tile_o_, + // Sequence length. + int S, + // The hidden dimension. + int D, + // The iteration step of the outer loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The version of the kernel. + int VERSION_, + // The mask version of the kernel, (2 denotes dense mask, 3 denotes causal mask) + int MASK_VERSION_ = 2, + // The flags to control the behaviour of LDGs. + uint32_t FLAGS = 0x8u> +struct FMHA_kernel_traits_hopper { + // The instruction traits for the Q*K product. + using Traits_p = Traits_p_; + + // is Q operand in RF for GMMA? + static constexpr bool GMMA_Q_RF = Traits_p::GMMA_A_RF; + + // is K operand in RF for GMMA? + static constexpr bool GMMA_K_RF = Traits_p::GMMA_B_RF; + + // The instruction traits for P*V product. + using Traits_o = Traits_o_; + + // is S operand in RF for GMMA? + static constexpr bool GMMA_S_RF = Traits_o::GMMA_A_RF; + + // is V operand in RF for GMMA? + static constexpr bool GMMA_V_RF = Traits_o::GMMA_B_RF; + + // The number of warpgroups along M dimension + enum { WARP_GROUP_M = WARPS_M / 4 }; + + // The number of warpgroups along N dimension + enum { WARP_GROUP_N = WARPS_N }; + + // The number of warpgroups along K dimension + enum { WARP_GROUP_K = 1 }; + + // The CTA description for the 1st GEMM. + using Cta_tile_p = + typename Traits_p::template Cta_tile; + // The CTA description for the 2nd GEMM. + using Cta_tile_o = + typename Traits_o::template Cta_tile; + + // The version. 
+ enum { VERSION = VERSION_ }; + + enum { MASK_VERSION = MASK_VERSION_ }; + + // Whether use causal mask or not. + enum { CAUSAL_MASK = MASK_VERSION_ >= 3 }; + + // Whether use the sliding window attention mask or not. + enum { SLIDING_WINDOW_ATTENTION = MASK_VERSION_ == 4 }; + + // Do we use LDGSTS for Q, K or V. If not, TMA is used! + enum { USE_LDGSTS_Q = (FLAGS & 0x1u) != 0u }; + + enum { USE_LDGSTS_K = (FLAGS & 0x2u) != 0u }; + + enum { USE_LDGSTS_V = (FLAGS & 0x4u) != 0u }; + + enum { USE_TMA_Q = !USE_LDGSTS_Q }; + + enum { USE_TMA_K = !USE_LDGSTS_K }; + + enum { USE_TMA_V = !USE_LDGSTS_V }; + + // Do we use one buffer for K and V. + enum { SHARE_SMEM_FOR_K_AND_V = 0 }; + + // Do we use the scale max trick. + enum { USE_SCALE_MAX = 0 }; + + // Are heads in QKV interleaved, i.e. total x h x 3 x d or total x 3 x h x d. + enum { HEADS_INTERLEAVED = (FLAGS & 0x20u) == 0u }; + + // Use BMM1 softcapping scale or not. + enum { ENABLE_BMM1_SOFTCAPPING_SCALE = (FLAGS & 0x800) != 0u }; + + // Number of matrix for gmem_tile_qkv + enum { NUM_QKV_MATS = 3 }; + + // The global memory tile to load Q. + // Hopefully we don't need to specialize for Hopper. + using Gmem_tile_ldgsts_q = + Gmem_tile_qkv_; + + // The global memory tile to load Q with TMA. + using Gmem_tile_tma_q = Gmem_tile_tma_qkv_; + + // Do we use ldgsts gmem tile or tma gmem tile? + using Gmem_tile_q = + typename std::conditional_t; + + // 2 buffers for Q + enum { BUFFERS_PER_SMEM_TILE_Q = 2 }; + + // Q is row major + using Q_layout = fmha::Row; + + // We know Q is row-major. So we can also deduce the descriptor mode. + static constexpr fmha::Gmma_descriptor_mode GMMA_DESC_MODE_Q = + Cta_tile_p::K * sizeof(typename Traits_p::A_type) >= 128 + ? fmha::Gmma_descriptor_mode::SWIZZLE_128B + : fmha::Gmma_descriptor_mode::SWIZZLE_64B; + + // The shared memory tile to swizzle Q. + using Smem_tile_ldgsts_q = + fmha::Smem_tile_hopper_a; + + // The shared memory tile to swizzle Q. TODO: need to update to XMMA. + using Smem_tile_tma_q = + fmha::wip::Smem_tile_hopper_a; + + using Smem_tile_q = + typename std::conditional_t; + + // The global memory tile to load K. + // Hopefully we don't need to specialize for hopper. + using Gmem_tile_ldgsts_k = + Gmem_tile_qkv_; + + // The global memory tile to load K with TMA. + using Gmem_tile_tma_k = Gmem_tile_tma_qkv_; + + // Do we use ldgsts gmem tile or tma gmem tile? + using Gmem_tile_k = + typename std::conditional_t; + + // 1 buffers for K + enum { BUFFERS_PER_SMEM_TILE_K = 1 }; + + // K is column major + using K_layout = fmha::Col; + + // We know K is column-major. So we can also deduce the descriptor mode. + static constexpr fmha::Gmma_descriptor_mode GMMA_DESC_MODE_K = + Cta_tile_p::K * sizeof(typename Traits_p::B_type) >= 128 + ? fmha::Gmma_descriptor_mode::SWIZZLE_128B + : fmha::Gmma_descriptor_mode::SWIZZLE_64B; + + // The shared memory tile to swizzle K. + using Smem_tile_ldgsts_k = + fmha::Smem_tile_hopper_b; + + using Smem_tile_tma_k = + fmha::wip::Smem_tile_hopper_b; + + using Smem_tile_k = + typename std::conditional_t; + + // The global memory tile to load V. + using Gmem_tile_ldgsts_v = + Gmem_tile_qkv_; + + // The global memory tile to load V with TMA. + using Gmem_tile_tma_v = Gmem_tile_tma_qkv_; + + // Do we use ldgsts gmem tile or tma gmem tile? + using Gmem_tile_v = + typename std::conditional_t; + + // 1 buffers for V + enum { BUFFERS_PER_SMEM_TILE_V = 1 }; + + // V is row major + using V_layout = fmha::Row; + + // We know V is row marjor. So we can also deduce the descriptor mode. 
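+ // For example, with an fp16/bf16 V tile and Cta_tile_o::N == 64 the contiguous extent is
+ // 64 * 2 = 128 bytes, so SWIZZLE_128B is picked; a 32-wide tile would select SWIZZLE_64B.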
+ static constexpr fmha::Gmma_descriptor_mode GMMA_DESC_MODE_V = + Cta_tile_o::N * sizeof(typename Traits_o::B_type) >= 128 + ? fmha::Gmma_descriptor_mode::SWIZZLE_128B + : fmha::Gmma_descriptor_mode::SWIZZLE_64B; + + // The shared memory tile to swizzle V. + using Smem_tile_ldgsts_v = fmha::Smem_tile_v; + + using Smem_tile_tma_v = + fmha::wip::Smem_tile_hopper_b; + + using Smem_tile_v = + typename std::conditional_t; + + // The global memory tile to store O. + // using Gmem_tile_o = fmha::Gmem_tile_o_hopper; + using Gmem_tile_o = fmha::v2::Gmem_tile_o; + + using Smem_tile_o_ = fmha::Smem_tile_o; + static constexpr bool NEEDS_SPLIT_K = WARPS_N > 1; + using Smem_tile_o = + typename std::conditional_t; + + // The amount of shared memory needed to load Q and K. + enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE }; + + // The extra amount of shared memory needed to load V. + enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE }; + + // The amount of shared memory needed for Q, K and V.. + enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V }; + + // The amount of shared memory needed to load Q and store O. + // enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE }; + // For now let's pretend no smem for O matrix. [Timmy] + enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE }; + + // The amount of over allocated smem to guarantee 1024B alignment. + enum { BYTES_FOR_ALIGNMENT = 1024 }; + + // The size in bytes for each SMEM barrier + enum { BYTES_PER_SMEM_BARRIER = 8 }; + + // The amount of smem used by smem barrier. Only needed if TMA is used. + enum { + BYTES_FOR_SMEM_BARRIER_Q = + USE_LDGSTS_Q == 1 ? 0 : BUFFERS_PER_SMEM_TILE_Q * BYTES_PER_SMEM_BARRIER + }; + + // The amount of smem used by smem barrier. Only needed if TMA is used. + // each smem barrier is 8 bytes, each buffer has 2 barriers + enum { + BYTES_FOR_SMEM_BARRIER_K = + USE_LDGSTS_K == 1 ? 0 : BUFFERS_PER_SMEM_TILE_K * BYTES_PER_SMEM_BARRIER + }; + + // The amount of smem used by smem barrier. Only needed if TMA is used. + // Currently, K and V can share the same barrier. + enum { BYTES_FOR_SMEM_BARRIER_V = 0 }; + + // The amount of smem used by smem barrier. Only needed if TMA is used. + enum { + BYTES_FOR_SMEM_BARRIER = + BYTES_FOR_SMEM_BARRIER_Q + BYTES_FOR_SMEM_BARRIER_K + BYTES_FOR_SMEM_BARRIER_V + }; + + // TODO move those + enum { BYTES_FOR_SOFTMAX = WARPS_N == 1 ? 0 : sizeof(float) * WARPS_N * 64 }; + + enum { + BYTES_PER_SMEM_O = + WARPS_N == 1 ? 0 : WARPS_N * 64 * D * sizeof(typename Traits_o::Epilogue_type) + }; + + static_assert(Smem_tile_o::BYTES_PER_TILE == (int)BYTES_PER_SMEM_O); + + // The amount of shared memory needed for Q, K, V and O. + // TODO double check. + // - For GMMA QKV are always stored in SMEM. + // - Cannot share SMEM K/V + // - O needs to be separate + // enum { BYTES_PER_SMEM = fmha::Max::VALUE + enum { + BYTES_PER_SMEM = BYTES_PER_SMEM_QKV + BYTES_PER_SMEM_O + BYTES_FOR_SOFTMAX + + BYTES_FOR_SMEM_BARRIER + BYTES_FOR_ALIGNMENT + }; + + // The number of threads. + enum { THREADS = Cta_tile_p::THREADS_PER_CTA }; + + // Make sure the number of threads matches both CTAs. + static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, ""); + + // The compute tile for P = Q*K. + using Compute_tile_p = + fmha::Compute_tile_with_gmma; + // The compute tile for O = S*V. 
+ using Compute_tile_o = + fmha::Compute_tile_with_gmma; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The BMM1 instruction traits. + typename Traits_p, + // The BMM2 instruction traits. + typename Traits_o, + // The sequence length. + int S, + // The hidden size per head. + int D, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The attention mask type (2 denotes dense mask, 3 denotes causal mask). + int MASK_VERSION, + // The flags. + uint32_t FLAGS = 0x8> +using FMHA_kernel_traits_hopper_v2 = + FMHA_kernel_traits_hopper; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/fmha_v2/fmha/hopper/smem_tile.h b/csrc/fmha_v2/fmha/hopper/smem_tile.h new file mode 100644 index 0000000000..b921b48db2 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/smem_tile.h @@ -0,0 +1,2423 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// +/// @brief Interface to Smem tiles for a operator +// HGMMA +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +enum class Gmma_fusion_mode { NO_FUSION, BN_APPLY }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace wip { + +template +struct Smem_tile_hopper_a {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_hopper_b {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// A Col Major. For GMMA, A is from SMEM directly. +// Not implemented, since it is not really needed at the moment. +template +struct Smem_tile_hopper_gmma_col_a {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// A Row Major. For GMMA, A is from SMEM directly. +template +struct Smem_tile_hopper_gmma_row_a { + // Currently Interleaved Mode is not implemented. + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_NONE, + "Currently, SWIZZLE_NONE Mode is not implemented.\n"); + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of desc within a gmma group (kblock limited). + static constexpr fmha::Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = + fmha::Gmma_descriptor_size::ONE; + + // The SWIZZLE_128B descriptor. + using Gmma_descriptor = + fmha::Gmma_descriptor_a; + + using Cta_tile_gmma = Cta_tile; + + // the size in bits of each element. 
+ enum { BITS_PER_ELEMENT = Traits::BITS_PER_ELEMENT_A }; + + // the size of bytes of each element. + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // The size in bytes of a single LDGSTS/STS. + enum { BYTES_PER_STS = 16 }; + + // The number of elements per LDGSTS/STS. + enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; + + // SMEM layout for GMMA has a leading dim of exact 128 Byte, at least for SWIZZLE_128B + // and SWIZZLE_64B format. + enum { BYTES_PER_ROW = 128 }; + + // the number of rows per one row of K due the the limitation of leading dim size. + enum { NUM_ROWS_PER_K = (Cta_tile::K * BYTES_PER_ELEMENT + BYTES_PER_ROW - 1) / BYTES_PER_ROW }; + + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_64B || + (Cta_tile::K * BYTES_PER_ELEMENT) == 64, + "swizzle_64B row_a is valid if kblock=32\n"); + + // Number of SMEM rows. + enum { + NUM_ROWS = (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) + ? (Cta_tile::M * NUM_ROWS_PER_K) + : (Cta_tile::M / 2) + }; + + // The size of one buffer in bytes in shared memory. + enum { BYTES_PER_BUFFER = NUM_ROWS * BYTES_PER_ROW }; + + // the size of one buffer in bytes in shared memory, without the 4 LSB. + // this is needed to increment the GMMA desc to the next buffer. + enum { BYTES_PER_BUFFER_NO_4LSB = BYTES_PER_BUFFER / 16 }; + + // this is needed to decrement GMMA desc. + enum { + BYTES_PER_BUFFER_INC_BOUNDARY_NO_4LSB = + BYTES_PER_BUFFER_NO_4LSB * BUFFERS_PER_TILE_ - BYTES_PER_BUFFER_NO_4LSB + }; + + // The number of buffers. + enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; + + // The size in bytes of total buffers. + enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; + + // The boundary for smem_read_offset and smem_write_offset increment. + enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; + + // The number of threads needed to store a row + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STS }; + + // The number of rows written with a single STS. + enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // for swizzle_128B the xor factor is 8 + enum { ROWS_PER_XOR_PATTERN = (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) ? 8 : 4 }; + + // The distance in byte between different GMMA groups (might need multiple due to cta tile size) + // each GMMA group is of size GMMA_M x GMMA_N x Kblock + enum { + GMMA_GROUP_SMEM_DISTANCE = Mma_tile::M_PER_GMMA_GROUP / + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 1 : 2) * + BYTES_PER_ROW + }; + + // The number of STS per row. + enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS }; + + // For Hopper, STS_PER_ROW should be 1 (at least for now.) + static_assert(STS_PER_ROW == 1, ""); + + // Ctor. + inline __device__ Smem_tile_hopper_gmma_row_a(char* smem, int tidx) + : smem_(__nvvm_get_smem_pointer(smem)) { + int smem_write_row = tidx / THREADS_PER_ROW; + int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN; + int smem_write_col = 0; + + if (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) { + smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor; + } else if (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_64B) { + smem_write_col = (tidx % (THREADS_PER_ROW / 2)) ^ + smem_write_xor + ((tidx % THREADS_PER_ROW) / (THREADS_PER_ROW / 2)) * 4; + } + + this->smem_write_offset_ = smem_write_row * BYTES_PER_ROW + smem_write_col * BYTES_PER_STS; + + // That code is expected to trigger the utilization of the URF by the compiler. 
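+ // (A __shfl_sync broadcast of a constant is warp-uniform by construction, which is
+ //  presumably what lets the compiler keep these buffer offsets in uniform registers.)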
+ this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0); + this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0); + } + + // Compute the store pointers. + template + inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) { +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + // Decompose the STS into row/col. + int row = ii / STS_PER_ROW; + // Assemble the offset. + int offset = smem_write_offset_ + row * ROWS_PER_STS * BYTES_PER_ROW; + // Assemble the final pointer :) + ptrs[ii] = smem_ + offset + smem_write_buffer_; + } + } + + // Store the tile in the shared memory. + template + inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t (&preds)[M], uint64_t = 0) { + uint32_t smem_ptrs[N]; + this->compute_store_pointers(smem_ptrs); + ldgsts(smem_ptrs, gmem_ptrs, preds); + } + + // Move the write offset to next buffer. + inline __device__ void move_next_write_buffer() { + if (BUFFERS_PER_TILE > 1) { + this->smem_write_offset_ += (smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY) + ? -BYTES_PER_TILE_INC_BOUNDARY + : BYTES_PER_BUFFER; + } + } + + inline __device__ void move_next_write_buffer(int) {} + + // Move the read offset to next buffer. + // do nothing, as it is controlled by gmma desc + inline __device__ void move_next_read_buffer() {} + + // The shared memory pointer. + uint32_t smem_; + // The read offset. Reserve 4 offsets if needed. + int smem_read_offset_; + // The write offset. + int smem_write_offset_; + // The buffer base offset for read. + int smem_read_buffer_; + // The buffer base offset for write. + int smem_write_buffer_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// B Col Major. For GMMA, B is from SMEM directly. +template +struct Smem_tile_hopper_gmma_col_b { + // Currently Interleaved Mode is not implemented. + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_NONE, + "Currently, Interleaved Mode is not implemented.\n"); + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of desc within a gmma group (kblock limited) + static constexpr fmha::Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = + fmha::Gmma_descriptor_size::ONE; + + // The SWIZZLE_128B descriptor + using Gmma_descriptor = + fmha::Gmma_descriptor_b; + + using Cta_tile_gmma = Cta_tile; + + // the size in bits of each element. + enum { BITS_PER_ELEMENT = Traits::BITS_PER_ELEMENT_B }; + + // the size of bytes of each element. + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // The size in bytes of a single LDGSTS/STS. + enum { BYTES_PER_STS = 16 }; + + // The number of elements per LDGSTS/STS. + enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; + + // SMEM layout for GMMA has a leading dim of exact 128 Byte, at least for SWIZZLE_128B and + // SWIZZLE_64B format + enum { BYTES_PER_COLUMN = 128 }; + + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_64B || + (Cta_tile::K * BYTES_PER_ELEMENT) == 64, + "swizzle_64B col_b is valid if kblock=32\n"); + + // the number of columns per one column of K due the the limitation of leading dim size + enum { + NUM_COLS_PER_K = (Cta_tile::K * BYTES_PER_ELEMENT + BYTES_PER_COLUMN - 1) / BYTES_PER_COLUMN + }; + + // Number of SMEM columns. + enum { + NUM_COLUMNS = (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) + ? Cta_tile::N * NUM_COLS_PER_K + : Cta_tile::N / 2 + }; + + // The size of one buffer in bytes in shared memory. 
+ enum { BYTES_PER_BUFFER = NUM_COLUMNS * BYTES_PER_COLUMN }; + + // the size of one buffer in bytes in shared memory, without the 4 LSB. + // this is needed to increment the GMMA desc to the next buffer + enum { BYTES_PER_BUFFER_NO_4LSB = BYTES_PER_BUFFER / 16 }; + + // this is needed to decrement GMMA desc. + enum { + BYTES_PER_BUFFER_INC_BOUNDARY_NO_4LSB = + BYTES_PER_BUFFER_NO_4LSB * BUFFERS_PER_TILE_ - BYTES_PER_BUFFER_NO_4LSB + }; + + // The number of buffers. + enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; + + // The size in bytes of total buffers. + enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; + + // The boundary for smem_read_offset and smem_write_offset increment. + enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; + + // The number of threads needed to store a column. + enum { THREADS_PER_COLUMN = BYTES_PER_COLUMN / BYTES_PER_STS }; + + // The number of columns written with a single STS. + enum { COLUMNS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_COLUMN }; + + // for swizzle_128B the xor factor is 8. + enum { + COLUMNS_PER_XOR_PATTERN = (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) ? 8 : 4 + }; + + // The distance in byte between different GMMA groups (might need multiple due to cta tile size) + // each GMMA group is of size GMMA_M x GMMA_N x Kblock + enum { + GMMA_GROUP_SMEM_DISTANCE = Mma_tile::N_PER_GMMA_GROUP / + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 1 : 2) * + BYTES_PER_COLUMN + }; + + // The number of STS per column. + enum { STS_PER_COLUMN = BYTES_PER_COLUMN / THREADS_PER_COLUMN / BYTES_PER_STS }; + + // For Hopper, STS_PER_COLUMN should be 1 (at least for now.) + static_assert(STS_PER_COLUMN == 1, ""); + + // Ctor. + inline __device__ Smem_tile_hopper_gmma_col_b(char* smem, int tidx) + : smem_(__nvvm_get_smem_pointer(smem)) { + int smem_write_col = tidx / THREADS_PER_COLUMN; + int smem_write_xor = smem_write_col % COLUMNS_PER_XOR_PATTERN; + int smem_write_row = 0; + + if (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) { + smem_write_row = (tidx % THREADS_PER_COLUMN) ^ smem_write_xor; + } else if (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_64B) { + smem_write_row = + (tidx % (THREADS_PER_COLUMN / 2)) ^ + smem_write_xor + ((tidx % THREADS_PER_COLUMN) / (THREADS_PER_COLUMN / 2)) * 4; + } + + this->smem_write_offset_ = smem_write_col * BYTES_PER_COLUMN + smem_write_row * BYTES_PER_STS; + // That code is expected to trigger the utilization of the URF by the compiler. + this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0); + this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0); + } + + // Compute the store pointers. + template + inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) { +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + // Decompose the STS into row/col. + int col = ii / STS_PER_COLUMN; + // Assemble the offset. + int offset = smem_write_offset_ + col * COLUMNS_PER_STS * BYTES_PER_COLUMN; + // Assemble the final pointer :) + ptrs[ii] = smem_ + offset + smem_write_buffer_; + } + } + + // Store the tile in the shared memory. + template + inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t (&preds)[M], uint64_t = 0) { + uint32_t smem_ptrs[N]; + this->compute_store_pointers(smem_ptrs); + ldgsts(smem_ptrs, gmem_ptrs, preds); + } + + // Move the write offset to next buffer. 
+ inline __device__ void move_next_write_buffer() { + // if( BUFFERS_PER_TILE > 1 ) { + // this->smem_write_offset_ += ( smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) + // ? -BYTES_PER_TILE_INC_BOUNDARY + // : BYTES_PER_BUFFER; + // } + } + + inline __device__ void move_next_write_buffer(int) {} + + // Move the read offset to next buffer. + // do nothing, as it is controlled by gmma desc + inline __device__ void move_next_read_buffer() {} + + // The shared memory pointer. + uint32_t smem_; + // The read offset. Reserve 4 offsets if needed. + int smem_read_offset_; + // The write offset. + int smem_write_offset_; + // The buffer base offset for read. + int smem_read_buffer_; + // The buffer base offset for write. + int smem_write_buffer_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// B Row Major. For GMMA, B is from SMEM directly. +template +struct Smem_tile_hopper_gmma_row_b { + // Currently Interleaved Mode is not implemented. + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_NONE, + "Currently, Interleaved Mode is not implemented.\n"); + + // For SWIZZLE_64B, row b is not needed/implemented + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_64B, + "Currently, for SWIZZLE_64B mode, row_b is not needed/implemented. \n"); + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of desc within a gmma group (kblock limited) + static constexpr fmha::Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = + fmha::Gmma_descriptor_size::ONE; + + // The SWIZZLE_128B descriptor + using Gmma_descriptor = + fmha::Gmma_descriptor_b; + + using Cta_tile_gmma = Cta_tile; + + // the size in bits of each element. + enum { BITS_PER_ELEMENT = Traits::BITS_PER_ELEMENT_B }; + + // the size of bytes of each element. + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // The size in bytes of a single LDGSTS/STS. + enum { BYTES_PER_STS = 16 }; + + // The number of elements per LDGSTS/STS. + enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; + + // SMEM layout for GMMA has a leading dim of exact 128 Byte, at least for SWIZZLE_128B and + // SWIZZLE_64B format + enum { BYTES_PER_ROW = 128 }; + + // the number of rows per one row of N due the the limitation of leading dim size + enum { NUM_ROWS_PER_N = (Cta_tile::N * BYTES_PER_ELEMENT + BYTES_PER_ROW - 1) / BYTES_PER_ROW }; + + // the number of rows per one row of N_PER_GMMA_GROUP + enum { + NUM_ROWS_PER_GMMA_GROUP_N = + (Mma_tile::N_PER_GMMA_GROUP * BYTES_PER_ELEMENT + BYTES_PER_ROW - 1) / BYTES_PER_ROW + }; + + // Number of SMEM rows + enum { NUM_ROWS = Cta_tile::K * NUM_ROWS_PER_N }; + + // The size of one buffer in bytes in shared memory. + enum { BYTES_PER_BUFFER = NUM_ROWS * BYTES_PER_ROW }; + + // the size of one buffer in bytes in shared memory, without the 4 LSB. + // this is needed to increment the GMMA desc to the next buffer + enum { BYTES_PER_BUFFER_NO_4LSB = BYTES_PER_BUFFER / 16 }; + + // this is needed to decrement GMMA desc + enum { + BYTES_PER_BUFFER_INC_BOUNDARY_NO_4LSB = + BYTES_PER_BUFFER_NO_4LSB * BUFFERS_PER_TILE_ - BYTES_PER_BUFFER_NO_4LSB + }; + + // The number of buffers. + enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; + + // The size in bytes of total buffers. + enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; + + // The boundary for smem_read_offset and smem_write_offset increment. 
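+  // Illustration (hedged, hypothetical numbers): move_next_write_buffer() in the row_a tile above
+  // advances the write offset through the ring of buffers as
+  //   offset += (offset >= BYTES_PER_TILE_INC_BOUNDARY) ? -BYTES_PER_TILE_INC_BOUNDARY
+  //                                                     : BYTES_PER_BUFFER;
+  // e.g. assuming BUFFERS_PER_TILE = 2 and BYTES_PER_BUFFER = 8192, the boundary below is 8192 and
+  // the write offset simply toggles between buffer 0 and buffer 1.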
+ enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; + + // The number of threads needed to store a row + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STS }; + + // The number of rows written with a single STS. + enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // for swizzle_128B the xor factor is 8 + enum { ROWS_PER_XOR_PATTERN = 8 }; + + // The distance in byte between different GMMA groups (might need multiple due to cta tile size) + // each GMMA group is of size GMMA_M x GMMA_N x Kblock + enum { + GMMA_GROUP_SMEM_DISTANCE = + Mma_tile::K_PER_GMMA_GROUP * NUM_ROWS_PER_GMMA_GROUP_N * BYTES_PER_ROW + }; + + // The number of STS per ROW. + enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS }; + + // For Hopper, STS_PER_ROW should be 1 (at least for now.) + static_assert(STS_PER_ROW == 1, ""); + + // Ctor. + inline __device__ Smem_tile_hopper_gmma_row_b(char* smem, int tidx) + : smem_(__nvvm_get_smem_pointer(smem)) { + int smem_write_row = tidx / THREADS_PER_ROW; + int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN; + int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor; + this->smem_write_offset_ = smem_write_row * BYTES_PER_ROW + smem_write_col * BYTES_PER_STS; + // That code is expected to trigger the utilization of the URF by the compiler. + this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0); + this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0); + } + + // Compute the store pointers. + template + inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) { +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + // Decompose the STS into row/col. + int row = ii / STS_PER_ROW; + // Assemble the offset. + int offset = smem_write_offset_ + row * ROWS_PER_STS * BYTES_PER_ROW; + + // Assemble the final pointer :) + ptrs[ii] = smem_ + offset + smem_write_buffer_; + } + } + + template + inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t (&preds)[M], uint64_t = 0) { + uint32_t smem_ptrs[N]; + this->compute_store_pointers(smem_ptrs); + ldgsts(smem_ptrs, gmem_ptrs, preds); + } + + // Move the write offset to next buffer. + inline __device__ void move_next_write_buffer() { + // if( BUFFERS_PER_TILE > 1 ) { + // this->smem_write_offset_ += ( smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) + // ? -BYTES_PER_TILE_INC_BOUNDARY + // : BYTES_PER_BUFFER; + // } + } + + inline __device__ void move_next_write_buffer(int) {} + + // Move the read offset to next buffer. + // do nothing, as it is controlled by gmma desc + inline __device__ void move_next_read_buffer() {} + + // The shared memory pointer. + uint32_t smem_; + // The read offset. Reserve 4 offsets if needed. + int smem_read_offset_; + // The write offset. + int smem_write_offset_; + // The buffer base offset for read. + int smem_read_buffer_; + // The buffer base offset for write. + int smem_write_buffer_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Specialized Interface +// LDGSTS smem tiles. +//////////////////////////////////////////////////////////////////////////////////////////////////// +// A Col Major, A coming from SMEM +template +struct Smem_tile_hopper_a + : public Smem_tile_hopper_gmma_col_a { + // The base class. + using Base = Smem_tile_hopper_gmma_col_a; + + // Ctor. + // comment the implementation out as a mark that this is not supported, yet. 
+ // inline __device__ Smem_tile_hopper_a( char *smem, int tidx ) : Base( smem, tidx ) { + //} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// A Row Major, A coming from SMEM +template +struct Smem_tile_hopper_a + : public Smem_tile_hopper_gmma_row_a { + // The base class. + using Base = Smem_tile_hopper_gmma_row_a; + + // Ctor. + inline __device__ Smem_tile_hopper_a(char* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// B Col Major, B coming from SMEM +template +struct Smem_tile_hopper_b + : public Smem_tile_hopper_gmma_col_b { + // The base class. + using Base = Smem_tile_hopper_gmma_col_b; + + // Ctor. + inline __device__ Smem_tile_hopper_b(char* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// B Row Major, B coming from SMEM +template +struct Smem_tile_hopper_b + : public Smem_tile_hopper_gmma_row_b { + // The base class. + using Base = Smem_tile_hopper_gmma_row_b; + + // Ctor. + inline __device__ Smem_tile_hopper_b(char* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Specialized Interface +// TMA smem tiles. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// A Row Major. For GMMA, A is from SMEM directly. +template +struct Smem_tile_hopper_gmma_tma_row_a { + // Currently Interleaved Mode is not implemented. + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_NONE, + "Currently, SWIZZLE_NONE Mode is not implemented.\n"); + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of desc within a gmma group (kblock limited). + static constexpr fmha::Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = + fmha::Gmma_descriptor_size::ONE; + + // The SWIZZLE_128B descriptor. + using Gmma_descriptor = + fmha::Gmma_descriptor_a; + + using Cta_tile_gmma = Cta_tile; + + // the size in bits of each element. + enum { BITS_PER_ELEMENT = Traits::BITS_PER_ELEMENT_A }; + + // the size of bytes of each element. + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // The size in bytes of a single LDGSTS/STS. + enum { BYTES_PER_STS = 16 }; + + // The number of elements per LDGSTS/STS. + enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; + + // SMEM layout for GMMA has a leading dim of exact 128 Byte, at least for SWIZZLE_128B + // and SWIZZLE_64B format. + enum { BYTES_PER_ROW = 128 }; + + // the number of rows per one row of K due the the limitation of leading dim size. + enum { NUM_ROWS_PER_K = (Cta_tile::K * BYTES_PER_ELEMENT + BYTES_PER_ROW - 1) / BYTES_PER_ROW }; + + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_64B || + (Cta_tile::K * BYTES_PER_ELEMENT) == 64, + "swizzle_64B row_a is valid if kblock=32\n"); + + // Number of SMEM rows. + enum { + NUM_ROWS = (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) + ? (Cta_tile::M * NUM_ROWS_PER_K) + : (Cta_tile::M / 2) + }; + + // The size of one buffer in bytes in shared memory. + enum { BYTES_PER_BUFFER = NUM_ROWS * BYTES_PER_ROW }; + + // the size of one buffer in bytes in shared memory, without the 4 LSB. 
+ // this is needed to increment the GMMA desc to the next buffer. + enum { BYTES_PER_BUFFER_NO_4LSB = BYTES_PER_BUFFER / 16 }; + + // this is needed to decrement GMMA desc. + enum { + BYTES_PER_BUFFER_INC_BOUNDARY_NO_4LSB = + BYTES_PER_BUFFER_NO_4LSB * BUFFERS_PER_TILE_ - BYTES_PER_BUFFER_NO_4LSB + }; + + // The number of buffers. + enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; + + // The size in bytes of total buffers. + enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; + + // The boundary for smem_read_offset and smem_write_offset increment. + enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; + + // The number of threads needed to store a row + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STS }; + + // The number of rows written with a single STS. + enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // for swizzle_128B the xor factor is 8 + enum { ROWS_PER_XOR_PATTERN = (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) ? 8 : 4 }; + + // The distance in byte between different GMMA groups (might need multiple due to cta tile size) + // each GMMA group is of size GMMA_M x GMMA_N x Kblock + enum { + GMMA_GROUP_SMEM_DISTANCE = Mma_tile::M_PER_GMMA_GROUP / + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 1 : 2) * + BYTES_PER_ROW + }; + + // The number of STS per row. + enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS }; + + // For Hopper, STS_PER_ROW should be 1 (at least for now.) + static_assert(STS_PER_ROW == 1, ""); + + // Each smem barrier is of 8 bytes + enum { BYTES_PER_SMEM_BARRIER = 8 }; + + // The boundary for smem_read_offset and smem_write_offset increment. + enum { + BYTES_PER_TILE_INC_BOUNDARY_SMEM_BARRIER = + BYTES_PER_SMEM_BARRIER * BUFFERS_PER_TILE - BYTES_PER_SMEM_BARRIER + }; + + // Ctor. + inline __device__ Smem_tile_hopper_gmma_tma_row_a(char* smem, char* smem_barrier) + : smem_(__nvvm_get_smem_pointer(smem)), + smem_barrier_(__nvvm_get_smem_pointer(smem_barrier)), + smem_write_offset_(0), + smem_barrier_offset_(0) {} + + // Move the write offset to next buffer. + // Also move the smem_barrier. + inline __device__ void move_next_write_buffer() { + if (BUFFERS_PER_TILE > 1) { + this->smem_write_offset_ += (smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY) + ? -BYTES_PER_TILE_INC_BOUNDARY + : BYTES_PER_BUFFER; + } + + // also update the smem_barrier. + if (BUFFERS_PER_TILE > 1) { + this->smem_barrier_offset_ += + (smem_barrier_offset_ >= BYTES_PER_TILE_INC_BOUNDARY_SMEM_BARRIER) + ? -BYTES_PER_TILE_INC_BOUNDARY_SMEM_BARRIER + : BYTES_PER_SMEM_BARRIER; + } + } + + inline __device__ void move_next_write_buffer(int) {} + + // Move the read offset to next buffer. + // do nothing, as it is controlled by gmma desc + inline __device__ void move_next_read_buffer() {} + + template + inline __device__ void store(cudaTmaDesc const* p_desc, int32_t const (&coord)[DIM], + uint16_t filter_offsets = 0, uint16_t mcast_cta_mask = 0) { + fmha::utmaldg(p_desc, smem_ + smem_write_offset_, + smem_barrier_ + smem_barrier_offset_, coord); + } + + // The shared memory pointer. + uint32_t smem_; + // The barrier in smem. + uint32_t smem_barrier_; + // The write offset. + int smem_write_offset_; + // The smem barrier offset + int smem_barrier_offset_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// B Col Major. For GMMA, B is from SMEM directly. 
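+// Sizing sketch (illustrative, assuming a hypothetical fp16 tile): with 2-byte elements and
+// Cta_tile::K = 64, one K-slice spans 64 * 2 = 128 bytes, so NUM_COLS_PER_K = 1 and, in
+// SWIZZLE_128B mode, NUM_COLUMNS = Cta_tile::N. For Cta_tile::N = 64 that gives
+// BYTES_PER_BUFFER = 64 * 128 = 8192 bytes and BYTES_PER_BUFFER_NO_4LSB = 8192 / 16 = 512,
+// which is the amount added to the GMMA descriptor when switching buffers.
+// Unlike the LDGSTS tiles above, the TMA variant below fills a buffer with a single utmaldg
+// and signals completion through a barrier in shared memory instead of per-thread stores.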
+template +struct Smem_tile_hopper_gmma_tma_col_b { + // Currently Interleaved Mode is not implemented. + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_NONE, + "Currently, Interleaved Mode is not implemented.\n"); + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of desc within a gmma group (kblock limited) + static constexpr fmha::Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = + fmha::Gmma_descriptor_size::ONE; + + // The SWIZZLE_128B descriptor + using Gmma_descriptor = + fmha::Gmma_descriptor_b; + + using Cta_tile_gmma = Cta_tile; + + // the size in bits of each element. + enum { BITS_PER_ELEMENT = Traits::BITS_PER_ELEMENT_B }; + + // the size of bytes of each element. + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // The size in bytes of a single LDGSTS/STS. + enum { BYTES_PER_STS = 16 }; + + // The number of elements per LDGSTS/STS. + enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; + + // SMEM layout for GMMA has a leading dim of exact 128 Byte, at least for SWIZZLE_128B and + // SWIZZLE_64B format + enum { BYTES_PER_COLUMN = 128 }; + + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_64B || + (Cta_tile::K * BYTES_PER_ELEMENT) == 64, + "swizzle_64B col_b is valid if kblock=32\n"); + + // the number of columns per one column of K due the the limitation of leading dim size + enum { + NUM_COLS_PER_K = (Cta_tile::K * BYTES_PER_ELEMENT + BYTES_PER_COLUMN - 1) / BYTES_PER_COLUMN + }; + + // Number of SMEM columns. + enum { + NUM_COLUMNS = (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) + ? Cta_tile::N * NUM_COLS_PER_K + : Cta_tile::N / 2 + }; + + // The size of one buffer in bytes in shared memory. + enum { BYTES_PER_BUFFER = NUM_COLUMNS * BYTES_PER_COLUMN }; + + // the size of one buffer in bytes in shared memory, without the 4 LSB. + // this is needed to increment the GMMA desc to the next buffer + enum { BYTES_PER_BUFFER_NO_4LSB = BYTES_PER_BUFFER / 16 }; + + // this is needed to decrement GMMA desc. + enum { + BYTES_PER_BUFFER_INC_BOUNDARY_NO_4LSB = + BYTES_PER_BUFFER_NO_4LSB * BUFFERS_PER_TILE_ - BYTES_PER_BUFFER_NO_4LSB + }; + + // The number of buffers. + enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; + + // The size in bytes of total buffers. + enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; + + // The boundary for smem_read_offset and smem_write_offset increment. + enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; + + // The number of threads needed to store a column. + enum { THREADS_PER_COLUMN = BYTES_PER_COLUMN / BYTES_PER_STS }; + + // The number of columns written with a single STS. + enum { COLUMNS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_COLUMN }; + + // for swizzle_128B the xor factor is 8. + enum { + COLUMNS_PER_XOR_PATTERN = (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) ? 8 : 4 + }; + + // The distance in byte between different GMMA groups (might need multiple due to cta tile size) + // each GMMA group is of size GMMA_M x GMMA_N x Kblock + enum { + GMMA_GROUP_SMEM_DISTANCE = Mma_tile::N_PER_GMMA_GROUP / + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 1 : 2) * + BYTES_PER_COLUMN + }; + + // The number of STS per column. + enum { STS_PER_COLUMN = BYTES_PER_COLUMN / THREADS_PER_COLUMN / BYTES_PER_STS }; + + // For Hopper, STS_PER_COLUMN should be 1 (at least for now.) + static_assert(STS_PER_COLUMN == 1, ""); + + // Ctor. 
+ inline __device__ Smem_tile_hopper_gmma_tma_col_b(char* smem, char* smem_barrier) + : smem_(__nvvm_get_smem_pointer(smem)), + smem_barrier_(__nvvm_get_smem_pointer(smem_barrier)) {} + + // Move the write offset to next buffer. + // Not implemented as it is not needed currently. + inline __device__ void move_next_write_buffer() {} + + inline __device__ void move_next_write_buffer(int) {} + + // Move the read offset to next buffer. + // do nothing, as it is controlled by gmma desc + inline __device__ void move_next_read_buffer() {} + + template + inline __device__ void store(cudaTmaDesc const* p_desc, int32_t const (&coord)[DIM], + uint16_t filter_offsets = 0, uint16_t mcast_cta_mask = 0) { + fmha::utmaldg(p_desc, smem_, smem_barrier_, coord); + } + + // The shared memory pointer. + uint32_t smem_; + // The barrier in smem. + uint32_t smem_barrier_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// B Row Major. For GMMA, B is from SMEM directly. +template +struct Smem_tile_hopper_gmma_tma_row_b { + // Currently Interleaved Mode is not implemented. + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_NONE, + "Currently, Interleaved Mode is not implemented.\n"); + + // For SWIZZLE_64B, row b is not needed/implemented + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_64B, + "Currently, for SWIZZLE_64B mode, row_b is not needed/implemented. \n"); + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of desc within a gmma group (kblock limited) + static constexpr fmha::Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = + fmha::Gmma_descriptor_size::ONE; + + // The SWIZZLE_128B descriptor + using Gmma_descriptor = + fmha::Gmma_descriptor_b; + + using Cta_tile_gmma = Cta_tile; + + // the size in bits of each element. + enum { BITS_PER_ELEMENT = Traits::BITS_PER_ELEMENT_B }; + + // the size of bytes of each element. + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // The size in bytes of a single LDGSTS/STS. + enum { BYTES_PER_STS = 16 }; + + // The number of elements per LDGSTS/STS. + enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; + + // SMEM layout for GMMA has a leading dim of exact 128 Byte, at least for SWIZZLE_128B and + // SWIZZLE_64B format + enum { BYTES_PER_ROW = 128 }; + + // the number of rows per one row of N due the the limitation of leading dim size + enum { NUM_ROWS_PER_N = (Cta_tile::N * BYTES_PER_ELEMENT + BYTES_PER_ROW - 1) / BYTES_PER_ROW }; + + // the number of rows per one row of N_PER_GMMA_GROUP + enum { + NUM_ROWS_PER_GMMA_GROUP_N = + (Mma_tile::N_PER_GMMA_GROUP * BYTES_PER_ELEMENT + BYTES_PER_ROW - 1) / BYTES_PER_ROW + }; + + // Number of SMEM rows + enum { NUM_ROWS = Cta_tile::K * NUM_ROWS_PER_N }; + + // The size of one buffer in bytes in shared memory. + enum { BYTES_PER_BUFFER = NUM_ROWS * BYTES_PER_ROW }; + + // the size of one buffer in bytes in shared memory, without the 4 LSB. + // this is needed to increment the GMMA desc to the next buffer + enum { BYTES_PER_BUFFER_NO_4LSB = BYTES_PER_BUFFER / 16 }; + + // this is needed to decrement GMMA desc + enum { + BYTES_PER_BUFFER_INC_BOUNDARY_NO_4LSB = + BYTES_PER_BUFFER_NO_4LSB * BUFFERS_PER_TILE_ - BYTES_PER_BUFFER_NO_4LSB + }; + + // The number of buffers. + enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; + + // The size in bytes of total buffers. 
+ enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; + + // The boundary for smem_read_offset and smem_write_offset increment. + enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; + + // The number of threads needed to store a row + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STS }; + + // The number of rows written with a single STS. + enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // for swizzle_128B the xor factor is 8 + enum { ROWS_PER_XOR_PATTERN = 8 }; + + // The distance in byte between different GMMA groups (might need multiple due to cta tile size) + // each GMMA group is of size GMMA_M x GMMA_N x Kblock + enum { + GMMA_GROUP_SMEM_DISTANCE = + Mma_tile::K_PER_GMMA_GROUP * NUM_ROWS_PER_GMMA_GROUP_N * BYTES_PER_ROW + }; + + // The number of STS per ROW. + enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS }; + + // For Hopper, STS_PER_ROW should be 1 (at least for now.) + static_assert(STS_PER_ROW == 1, ""); + + // Ctor. + inline __device__ Smem_tile_hopper_gmma_tma_row_b(char* smem, char* smem_barrier) + : smem_(__nvvm_get_smem_pointer(smem)), + smem_barrier_(__nvvm_get_smem_pointer(smem_barrier)) {} + + // Move the write offset to next buffer. + // Not implemented since it is not needed at the moment. + inline __device__ void move_next_write_buffer() {} + + inline __device__ void move_next_write_buffer(int) {} + + // Move the read offset to next buffer. + // do nothing, as it is controlled by gmma desc + inline __device__ void move_next_read_buffer() {} + + template + inline __device__ void store(cudaTmaDesc const* p_desc, int32_t const (&coord)[DIM], + uint16_t filter_offsets = 0, uint16_t mcast_cta_mask = 0) { + fmha::utmaldg(p_desc, smem_, smem_barrier_, coord); + } + + // The shared memory pointer. + uint32_t smem_; + // The barrier in smem. + uint32_t smem_barrier_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// A Row Major, A coming from SMEM +template +struct Smem_tile_hopper_a + : public Smem_tile_hopper_gmma_tma_row_a { + // The base class. + using Base = Smem_tile_hopper_gmma_tma_row_a; + + // Ctor. + inline __device__ Smem_tile_hopper_a(char* smem, char* smem_barrier) : Base(smem, smem_barrier) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// B Col Major, B coming from SMEM +template +struct Smem_tile_hopper_b + : public Smem_tile_hopper_gmma_tma_col_b { + // The base class. + using Base = Smem_tile_hopper_gmma_tma_col_b; + + // Ctor. + inline __device__ Smem_tile_hopper_b(char* smem, char* smem_barrier) : Base(smem, smem_barrier) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// B Row Major, B coming from SMEM +template +struct Smem_tile_hopper_b + : public Smem_tile_hopper_gmma_tma_row_b { + // The base class. + using Base = Smem_tile_hopper_gmma_tma_row_b; + + // Ctor. + inline __device__ Smem_tile_hopper_b(char* smem, char* smem_barrier) : Base(smem, smem_barrier) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace wip + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits_, + // The description of the tile computed by this CTA. + typename Cta_tile_, + // The layout of the tile. 
+ typename Layout_, + // The number of bytes per STS. + int BYTES_PER_STS_, + // The number of buffers. (Used in multistage and double buffer cases.) + int BUFFERS_PER_TILE_, + // GMMA descriptor mode + fmha::Gmma_descriptor_mode desc_mode, + // Whether to use TMA. + bool USE_TMA, + // Whether A is coming for RF. + bool GMMA_A_RF = Traits_::GMMA_A_RF> +struct Smem_tile_hopper_a : public fmha::Smem_tile_without_skews< + Cta_tile_, Layout_::COL ? Cta_tile_::K : Cta_tile_::M, + Layout_::COL ? Cta_tile_::M : Cta_tile_::K, + Traits_::BITS_PER_ELEMENT_A, BYTES_PER_STS_, BUFFERS_PER_TILE_, 0, + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 8 + : desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_64B ? 4 + : 2), + 1, true, USE_TMA, 128 * 8 / Traits_::BITS_PER_ELEMENT_A> { + using Traits = Traits_; + using Cta_tile = Cta_tile_; + // The base class. + using Base = fmha::Smem_tile_without_skews< + Cta_tile, Layout_::COL ? Cta_tile::K : Cta_tile::M, Layout_::COL ? Cta_tile::M : Cta_tile::K, + Traits::BITS_PER_ELEMENT_A, BYTES_PER_STS_, BUFFERS_PER_TILE_, 0, + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 8 + : desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_64B ? 4 + : 2), + 1, true, USE_TMA, 128 * 8 / Traits::BITS_PER_ELEMENT_A>; + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The layout + using Layout = Layout_; + // The fragment. + using Fragment = fmha::Fragment_a; + + // The number of desc within a gmma group (kblock limited) + static constexpr fmha::Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = + fmha::Gmma_descriptor_size::ONE; + // The SWIZZLE_128B descriptor + using Gmma_descriptor = + fmha::Gmma_descriptor_a; + + // the number of columns per one column of M_PER_GMMA_GROUP + enum { + NUM_COLS_PER_GMMA_GROUP_M = + (Mma_tile::M_PER_GMMA_GROUP * Base::BITS_PER_ELEMENT / 8 + Base::BYTES_PER_ROW - 1) / + Base::BYTES_PER_ROW + }; + + // The distance in byte between different GMMA groups (might need multiple due to cta tile size) + // each GMMA group is of size GMMA_M x GMMA_N x Kblock + static constexpr int GMMA_GROUP_SMEM_DISTANCE = + Layout::COL ? (Mma_tile::K_PER_GMMA_GROUP * NUM_COLS_PER_GMMA_GROUP_M * Base::BYTES_PER_ROW * + Cta_tile::WARP_GROUP_M) + : (Mma_tile::M_PER_GMMA_GROUP * Cta_tile::WARP_GROUP_M / + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 1 + : desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_64B ? 2 + : 4) * + Base::BYTES_PER_ROW); + + // the size of one buffer in bytes in shared memory, without the 4 LSB. + // this is needed to increment the GMMA desc to the next buffer + enum { BYTES_PER_BUFFER_NO_4LSB = Base::BYTES_PER_BUFFER / 16 }; + + // this is needed to decrement GMMA desc + enum { + BYTES_PER_BUFFER_INC_BOUNDARY_NO_4LSB = + BYTES_PER_BUFFER_NO_4LSB * BUFFERS_PER_TILE_ - BYTES_PER_BUFFER_NO_4LSB + }; + + // Ctor. + inline __device__ Smem_tile_hopper_a(void* smem, int tidx) : Base(smem, tidx) {} + + // set the scale and bias smem pointer + inline __device__ void set_scale_bias_smem_ptr(char* scale_bias_smem_ptr, int tidx, int k) {} + + // Load from shared memory. + template + inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) {} + + // Move the read offset to next buffer. 
+ // do nothing, as it is controlled by gmma desc + inline __device__ void move_next_read_buffer() {} + + // Overload set needs to be replicated for compatibility + inline __device__ void move_next_read_buffer(int N) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits_, + // The description of the tile computed by this CTA. + typename Cta_tile_, + // The layout of the tile. + typename Layout_, + // The number of bytes per STS. + int BYTES_PER_STS_, + // The number of buffers. (Used in multistage and double buffer cases.) + int BUFFERS_PER_TILE_, + // GMMA descriptor mode + fmha::Gmma_descriptor_mode desc_mode, + // USe TMA or not, + bool USE_TMA> +struct Smem_tile_hopper_b + : public fmha::Smem_tile_without_skews< + Cta_tile_, + Layout_::COL ? Cta_tile_::N : Cta_tile_::K, // ROWS + Layout_::COL ? Cta_tile_::K : Cta_tile_::N, // COLS + Traits_::BITS_PER_ELEMENT_B, BYTES_PER_STS_, BUFFERS_PER_TILE_, + 0, // LDS_FAST_PATH + // Determine ROWS_PER_XOR_PATTERN from the swizzle mode: + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 8 + : desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_64B ? 4 + : /* 32B or NONE */ 2), + 1, // COLS_PER_XOR_PATTERN + true, // USE_PREDICATES + USE_TMA, + 128 * 8 / Traits_::BITS_PER_ELEMENT_B // LEAD_DIM_ELEMENTS + > { + using Traits = Traits_; + using Cta_tile = Cta_tile_; + // The base class. + using Base = fmha::Smem_tile_without_skews< + Cta_tile, Layout_::COL ? Cta_tile::N : Cta_tile::K, Layout_::COL ? Cta_tile::K : Cta_tile::N, + Traits::BITS_PER_ELEMENT_B, BYTES_PER_STS_, BUFFERS_PER_TILE_, 0, + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 8 + : desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_64B ? 4 + : 2), + 1, true, USE_TMA, 128 * 8 / Traits::BITS_PER_ELEMENT_B>; + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The layout + using Layout = Layout_; + // The fragment. + using Fragment = fmha::Fragment_b; + + // The number of desc within a gmma group (kblock limited) + static constexpr fmha::Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = + fmha::Gmma_descriptor_size::ONE; + // The SWIZZLE_128B descriptor + using Gmma_descriptor = + fmha::Gmma_descriptor_b; + + // the number of rows per one row of N_PER_GMMA_GROUP + enum { + NUM_ROWS_PER_GMMA_GROUP_N = + (Mma_tile::N_PER_GMMA_GROUP * Base::BITS_PER_ELEMENT / 8 + Base::BYTES_PER_ROW - 1) / + Base::BYTES_PER_ROW + }; + + // The distance in byte between different GMMA groups (might need multiple due to cta tile size) + // each GMMA group is of size GMMA_M x GMMA_N x Kblock + + // The dimension that we split. + // Add buffers when we have multiple buffers for split head dimensions. + // Split-d smem view (2 split D, and 3 buffers): d0, d0, d0, d1, d1, d1. + static constexpr int GMMA_GROUP_SPLIT_DIM = + Layout::COL ? Mma_tile::N_PER_GMMA_GROUP : (Mma_tile::K_PER_GMMA_GROUP * BUFFERS_PER_TILE_); + + // The split factor. + static constexpr int GMMA_GROUP_SPLIT_FACTOR = + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 1 + : desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_64B ? 2 + : 4); + + // Make sure the dimension that we split is a multiple of the split factor. + static_assert(GMMA_GROUP_SPLIT_DIM % GMMA_GROUP_SPLIT_FACTOR == 0); + + // The distance between two "groups" in shared memory. 
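+  // Worked example (hedged, hypothetical numbers): for a row-major B tile in SWIZZLE_128B mode the
+  // split factor is 1 and the split dimension is K_PER_GMMA_GROUP * BUFFERS_PER_TILE_. Assuming
+  // K_PER_GMMA_GROUP = 64, BUFFERS_PER_TILE_ = 2 and a 128-byte leading dimension
+  // (Base::BYTES_PER_ROW), the distance below evaluates to (64 * 2) / 1 * 128 = 16384 bytes
+  // between consecutive GMMA groups.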
+ static constexpr int GMMA_GROUP_SMEM_DISTANCE = + GMMA_GROUP_SPLIT_DIM / GMMA_GROUP_SPLIT_FACTOR * Base::BYTES_PER_ROW; + + // the size of one buffer in bytes in shared memory, without the 4 LSB. + // this is needed to increment the GMMA desc to the next buffer + enum { BYTES_PER_BUFFER_NO_4LSB = Base::BYTES_PER_BUFFER / 16 }; + + // this is needed to decrement GMMA desc + enum { + BYTES_PER_BUFFER_INC_BOUNDARY_NO_4LSB = + BYTES_PER_BUFFER_NO_4LSB * BUFFERS_PER_TILE_ - BYTES_PER_BUFFER_NO_4LSB + }; + + // Ctor. + inline __device__ Smem_tile_hopper_b(void* smem, int tidx) : Base(smem, tidx) { + warp_id_ = tidx / 32; + lane_id_ = tidx % 32; + + // each pair of warps transposes 8x8 in place + // each warp responsible for diagonal 4x4s + // calculate index in 8x8 block + block_row_ = lane_id_ / 4; + block_col_ = (lane_id_ % 4) + ((warp_id_ % 2) ^ (block_row_ / 4)) * 4; + + // diagonal 4x4s will 2x conflict for SWIZZLE_32B + // 1 warp per 8x8, 2 4x8 load+store + if (Traits::GMMA_N == 8) { + block_row_ = lane_id_ / 8; + block_col_ = lane_id_ % 8; + } + + // offset when all 4 warps participate in transpose + block_col_offset_ = (warp_id_ / 2) * 8; + } + + int warp_id_, lane_id_; + int block_row_, block_col_, block_col_offset_; + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {} + + // Load from smem, do something (e.g. transpose), then store back to smem + inline __device__ void load_and_store(int ki) { + /* + using B_type = typename Traits::B_type; + + // TODO: move these to B_RF smem tiles + + // 8 channel per group fp16 fprop/dgrad with 64x16x16 gmma + // move 8x8 OOB zeros to right diagonal, 8x8 in-bounds weights on left diagonal + if (Cta_tile::N_PER_GROUP == 8 && Traits::GMMA_N == 16 + && Traits::BITS_PER_ELEMENT_B == 16) { + // just need to swap 2 cores within a single SWIZZLE_32B, one of which is just zero + // 1 LDSM.M88.1 + if (warp_id_ == 0) { + int smem_row_offset = ki * 4 * 128 + 2 * 128; // 4 rows per 16x16, swap the bottom 8x16 + int lds_block_idx = lane_id_ * 2; // ldsm.m88.1 only uses first 8 threads for address + int lds_smem_idx = lds_block_idx ^ (lane_id_ / 4); + + uint32_t data; + uint32_t lds_smem_ptr = this->smem_ + this->smem_read_buffer_ + + smem_row_offset + + lds_smem_idx * 16; + fmha::ldsm(data, lds_smem_ptr); + + __syncwarp(); + + // move values to adjacent core + fmha::stsm(lds_smem_ptr ^ 16, data); + + // set zeros at previous core + fmha::stsm(lds_smem_ptr, static_cast(0)); + } + } + + // 4 channel per group tf32 fprop with 64x8x8 gmma + // move 4x4 in-bounds weights on left diagonal, OOB zeros everywhere else + if (Cta_tile::N_PER_GROUP == 4 && Traits::GMMA_N == 8 + && Layout::COL && Traits::BITS_PER_ELEMENT_B == 32) { + // just need to swap the bottom 4x8, 1 elt per thread for 1 warp + // 1 lds/sts.32 per thread + if (warp_id_ == 0) { + int smem_row_offset = ki * Base::ROWS_PER_XOR_PATTERN * 128 + 128; + int lds_smem_idx = lane_id_; + uint32_t lds_ptr = this->smem_ + this->smem_read_buffer_ + + smem_row_offset + + lds_smem_idx * sizeof(B_type); + uint32_t data; + lds(data, lds_ptr); + + __syncwarp(); + + sts(lds_ptr ^ 16, data); + } + } + + // partial transpose of 8xN_PER_GROUP operand for tf32 grouped dgrad + // todo: revise this for tf32 grouped wgrad, move to partial specialization + static constexpr bool IS_TF32_GROUPED_DGRAD = + (Cta_tile::GROUPS_N > 1 && Cta_tile::GROUPS_K > 1 || Cta_tile::N_PER_GROUP == 32) + && Layout::ROW && Traits::BITS_PER_ELEMENT_B == 32; + if (IS_TF32_GROUPED_DGRAD) { + static 
constexpr int XOR_SCALE = 16 / sizeof(B_type); // 16B swizzle over 4B elements + static constexpr int ROWS_PER_128B = kDivUp( 128, Traits::GMMA_N * sizeof(B_type) ); + + if (Traits::GMMA_N == 8) { + if (warp_id_ == 0) { + + int smem_row_offset = ki * Base::ROWS_PER_XOR_PATTERN * 128; + uint32_t data[2]; + + #pragma unroll + for (int ii = 0; ii < 2; ii++) { + // get index in row-major 8x8 + int lds_block_row = block_row_ + ii * 4; + int lds_block_col = block_col_; + int lds_block_idx = lds_block_row * 8 + lds_block_col; + + // swizzle + int lds_xor_factor = (lds_block_row / ROWS_PER_128B) * XOR_SCALE; + int lds_smem_idx = lds_block_idx ^ lds_xor_factor; + + // Load from smem + uint32_t lds_ptr = this->smem_ + this->smem_read_buffer_ + + smem_row_offset + + lds_smem_idx * sizeof(B_type); + lds(data[ii], lds_ptr); + } + + __syncwarp(); + + #pragma unroll + for (int ii = 0; ii < 2; ii++) { + // get index in col-major 8x8 + int sts_block_row = block_col_; + int sts_block_col = block_row_ + ii * 4; + if (Cta_tile::N_PER_GROUP == 4 && ii == 1) { + // place 4x4 weights on diagonal for 4-channel tf32 group dgrad + sts_block_row ^= 4; + } + int sts_block_idx = sts_block_row * 8 + sts_block_col; + + // swizzle + int sts_xor_factor = (sts_block_row / ROWS_PER_128B) * XOR_SCALE; + int sts_smem_idx = sts_block_idx ^ sts_xor_factor; + + // store to smem + uint32_t sts_ptr = this->smem_ + this->smem_read_buffer_ + + smem_row_offset + + sts_smem_idx * sizeof(B_type); + sts(sts_ptr, data[ii]); + } + + } // warp_id == 0 + } else { + // loop over 8x16 blocks + #pragma unroll + for (int ii = 0; ii < kDivUp(Cta_tile::N_PER_GROUP, 16); ii++) { + int smem_row_offset = ki * Base::ROWS_PER_XOR_PATTERN * 128; + + // get index in row-major 8xN_PER_GROUP + int lds_block_row = block_row_; + int lds_block_col = block_col_ + block_col_offset_ + ii * 16; + int lds_block_idx = lds_block_row * Cta_tile::N_PER_GROUP + + lds_block_col; + + // swizzle + int lds_xor_factor = (lds_block_row / ROWS_PER_128B) * XOR_SCALE; + int lds_smem_idx = lds_block_idx ^ lds_xor_factor; + + // Load from smem + uint32_t lds_ptr = this->smem_ + this->smem_read_buffer_ + + smem_row_offset + + lds_smem_idx * sizeof(B_type); + uint32_t data; + lds(data, lds_ptr); + + __syncwarp(); + + // get index in row-major 8xN_PER_GROUP with 8x8 in-place transposes + int sts_block_row = block_col_; + int sts_block_col = block_row_ + block_col_offset_ + ii * 16; + int sts_block_idx = sts_block_row * Cta_tile::N_PER_GROUP + + sts_block_col; + + // swizzle + int sts_xor_factor = (sts_block_row / ROWS_PER_128B) * XOR_SCALE; + int sts_smem_idx = sts_block_idx ^ sts_xor_factor; + + // store to smem + uint32_t sts_ptr = this->smem_ + this->smem_read_buffer_ + + smem_row_offset + + sts_smem_idx * sizeof(B_type); + sts(sts_ptr, data); + } + } + } + + // make sure sts are visible to gmma + fence_view_async_shared(); + */ + } + + // Move the read offset to next buffer. + inline __device__ void move_next_read_buffer() {} + + // Move the read offset to next buffer. + inline __device__ void move_next_read_buffer(int buffer_id) { + this->smem_read_buffer_ = buffer_id * Base::BYTES_PER_BUFFER; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < // GMMA instruction shape in M dim + int GMMA_M, + // GMMA instruction shape in N dim + int GMMA_N, + // GMMA instruction shape in K dim + int GMMA_K, + // GMMA A operand coming from RF? + bool GMMA_A_RF, + // GMMA B operand coming from RF? 
+ bool GMMA_B_RF, + // The description of the tile computed by this CTA. + typename Cta_tile, + // GMMA descriptor mode + fmha::Gmma_descriptor_mode desc_mode, + // Use TMA or not, + bool USE_TMA, int BUFFERS_PER_TILE> +struct Smem_tile_v, + Cta_tile, BUFFERS_PER_TILE, desc_mode, USE_TMA> + : public fmha::Smem_tile_hopper_b< + fmha::Hopper_hgmma_fp16_traits, Cta_tile, + fmha::Row, + 16, // BYTES_PER_STS + BUFFERS_PER_TILE, desc_mode, USE_TMA> { + static constexpr bool TRANSPOSE = false; + + using Cta_tile_gmma = Cta_tile; + + using Base = fmha::Smem_tile_hopper_b< + fmha::Hopper_hgmma_fp16_traits, Cta_tile, + fmha::Row, + 16, // BYTES_PER_STS + BUFFERS_PER_TILE, desc_mode, USE_TMA>; + + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} + + inline __device__ void transpose_tile(int) { + // Transpose is fused into HGMMA. + } + + inline __device__ void transpose_tile(int, uint32_t, uint32_t) { + // Transpose is fused into HGMMA. + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < // GMMA instruction shape in M dim + int GMMA_M, + // GMMA instruction shape in N dim + int GMMA_N, + // GMMA instruction shape in K dim + int GMMA_K, + // GMMA A operand coming from RF? + bool GMMA_A_RF, + // GMMA B operand coming from RF? + bool GMMA_B_RF, + // The description of the tile computed by this CTA. + typename Cta_tile, + // GMMA descriptor mode + fmha::Gmma_descriptor_mode desc_mode, + // Use TMA or not, + bool USE_TMA, int BUFFERS_PER_TILE> +struct Smem_tile_v, + Cta_tile, BUFFERS_PER_TILE, desc_mode, USE_TMA> + : public fmha::Smem_tile_hopper_b< + fmha::Hopper_hgmma_fp32_traits, Cta_tile, + fmha::Row, + 16, // BYTES_PER_STS + BUFFERS_PER_TILE, // BUFFERS_PER_TILE, + desc_mode, USE_TMA> { + static constexpr bool TRANSPOSE = false; + + using Cta_tile_gmma = Cta_tile; + + using Base = fmha::Smem_tile_hopper_b< + fmha::Hopper_hgmma_fp32_traits, Cta_tile, + fmha::Row, + 16, // BYTES_PER_STS + BUFFERS_PER_TILE, // BUFFERS_PER_TILE, + desc_mode, USE_TMA>; + + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} + + inline __device__ void transpose_tile(int) { + // Transpose is fused into HGMMA. + } + + inline __device__ void transpose_tile(int, uint32_t, uint32_t) { + // Transpose is fused into HGMMA. + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < // GMMA instruction shape in M dim + int GMMA_M, + // GMMA instruction shape in N dim + int GMMA_N, + // GMMA instruction shape in K dim + int GMMA_K, + // GMMA A operand coming from RF? + bool GMMA_A_RF, + // GMMA B operand coming from RF? + bool GMMA_B_RF, + // The description of the tile computed by this CTA. 
+ typename Cta_tile, + // GMMA descriptor mode + fmha::Gmma_descriptor_mode desc_mode, + // Use TMA or not, + bool USE_TMA, int BUFFERS_PER_TILE> +struct Smem_tile_v, + Cta_tile, BUFFERS_PER_TILE, desc_mode, USE_TMA> + : public fmha::Smem_tile_hopper_b< + fmha::Hopper_hgmma_bf16_traits, Cta_tile, + fmha::Row, + 16, // BYTES_PER_STS + BUFFERS_PER_TILE, // BUFFERS_PER_TILE, + desc_mode, USE_TMA> { + static constexpr bool TRANSPOSE = false; + + using Cta_tile_gmma = Cta_tile; + + using Base = fmha::Smem_tile_hopper_b< + fmha::Hopper_hgmma_bf16_traits, Cta_tile, + fmha::Row, + 16, // BYTES_PER_STS + BUFFERS_PER_TILE, // BUFFERS_PER_TILE, + desc_mode, USE_TMA>; + + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} + + inline __device__ void transpose_tile(int) { + // Transpose is fused into HGMMA. + } + + inline __device__ void transpose_tile(int, uint32_t, uint32_t) { + // Transpose is fused into HGMMA. + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Transposer {}; + +template +struct Transposer { + static_assert(Cta_tile::K % 128 == 0); + + enum { + WARPS_M = Cta_tile::WARPS_M, + WARPS_N = Cta_tile::WARPS_N, + WARPS_K = Cta_tile::WARPS_K, + }; + + enum { + WARPS_4x1x1 = (WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1), + WARPS_4x1x2 = (WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 2), + }; + + enum { BYTES_PER_LDS = 16 }; + + enum { BYTES_PER_ROW = 128 }; + + // D=64 and 4 warps. + // Per warp we load 32 rows x 16 columns with LDSM.Tx4, 128 rows per CTA. + enum { S = Cta_tile::K >= 128 ? 128 : Cta_tile::K }; // The sequence length. + + enum { D = Cta_tile::N >= 128 ? 128 : Cta_tile::N }; // The head dimension. + + // static_assert(S % 128 == 0); + static_assert(WARPS_4x1x1 || WARPS_4x1x2); + static_assert(D % (BYTES_PER_LDS * WARPS_K) == 0); + + enum { ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING = 128 }; // LDSMx4 + + enum { ROW_PACKING = BYTES_PER_ROW / (D * sizeof(typename Traits::B_type)) }; + + enum { ROWS_PER_LDSM_PER_CTA = ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING / ROW_PACKING }; + + enum { ROWS_PER_XOR_PATTERN = fmha::Rows_per_xor_pattern_ampere_b::VALUE }; + + static_assert(ROWS_PER_XOR_PATTERN == 8); + + // The number of loads in K dimension. + enum { K = S / ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING }; + + // static_assert(K * ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING == S); + // static_assert(K == 3); + // The number of loads in the D dimension. + enum { N = D / (BYTES_PER_LDS * WARPS_K) }; // 16 bytes per load + + static_assert(N * BYTES_PER_LDS * WARPS_K == D); + + uint4 regs_[UNROLL_N][K]; + + uint32_t read_offset_; + uint32_t write_offset_; + uint32_t smem_read_loc_; + uint32_t smem_write_loc_; + + inline __device__ Transposer(int tidx) { + int read_row, read_col; + + if (WARPS_4x1x1 && N == 8) { // D=128, 1 warp in N + read_row = (tidx & 0x7f); + read_col = (tidx & 0x07); + } else if (WARPS_4x1x1 && N == 4) { // D=64, 1 warp in N + read_row = (tidx & 0xe0) / 2 + (tidx & 0x1e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + } else if (WARPS_4x1x1 && N == 2) { // D=32, 1 warp in N + read_row = (tidx & 0x60) / 4 + (tidx & 0x1c) / 4; + read_col = (tidx & 0x03) * 2; + read_col ^= (read_row & 0x01); + } else if (WARPS_4x1x2 && N == 4) { // D=128, 2 warps in N + read_row = (tidx & 0x7f); + read_col = (tidx & 0x07); + // For two warpgroups we do two steps in N at once. 
+ read_col ^= (tidx & 0x80) / 128; + } else if (WARPS_4x1x2 && N == 2) { // D=64, 2 warps in N + read_row = (tidx & 0x60) / 2 + (tidx & 0x1e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + // For two warpgroups we do two steps in N at once. + read_col ^= (tidx & 0x80) / 128; + } else if (WARPS_4x1x2 && N == 1) { // D=32, 2 warps in N + read_row = (tidx & 0x60) / 4 + (tidx & 0x1c) / 4; + read_col = (tidx & 0x03) * 2; + read_col ^= (read_row & 0x01); + // For two warpgroups we do two steps in N at once. + read_col ^= (tidx & 0x80) / 128; + } else { + assert(false); + } + + read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + + int write_row, write_col; + if (WARPS_4x1x1) { // swizzle_128byte + write_row = (tidx & 0x10) / 2 + (tidx & 0x07); + write_col = (tidx & 0x60) / 16 + (tidx & 0x08) / 8; + } else if (WARPS_4x1x2) { + // Same as above, with second warp group writing next 16 rows. + write_row = (tidx & 0x80) / 8 + (tidx & 0x10) / 2 + (tidx & 0x07); + write_col = (tidx & 0x60) / 16 + (tidx & 0x08) / 8; + } else { + assert(false); + } + + write_col ^= (write_row & 0x07); + + write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_LDS; + } + + inline __device__ void transpose(int tidx, uint32_t smem) { transpose_(tidx, smem, smem); } + + template + inline __device__ void transpose_(uint32_t smem_src, uint32_t smem_dst) { +#pragma unroll + for (int n_begin = 0; n_begin < N; n_begin += UNROLL_N) { + transpose_ldmatrix(n_begin, smem_src); + transpose_stmatrix(n_begin, smem_dst); + } + } + + inline __device__ void transpose_ldmatrix(int n_begin, uint32_t smem_src) { + static_assert(N % UNROLL_N == 0, ""); + + uint4 tmp[UNROLL_N][K]; + if (n_begin == 0) { + smem_read_loc_ = smem_src + read_offset_; + } + +#pragma unroll + for (int ni = n_begin; ni < n_begin + UNROLL_N; ni++) { + int const nii = ni - n_begin; +#pragma unroll + for (int ki = 0; ki < K; ki++) { // 2 + fmha::ldsmt(tmp[nii][ki], smem_read_loc_ + ki * ROWS_PER_LDSM_PER_CTA * BYTES_PER_ROW); + } + + if (WARPS_4x1x1 && N == 4) { // D=64, 1 warp in N + smem_read_loc_ ^= (ni % 2 == 0 ? 1 : 3) * 16; + } else if (WARPS_4x1x1 && N == 2) { // D=32, 1 warp in N + smem_read_loc_ ^= 16; + } else if (WARPS_4x1x2 && N == 2) { // D=64, 2 warps in N + smem_read_loc_ ^= 32; + } else if (WARPS_4x1x2 && N == 4) { // D=128, 2 warps in N + smem_read_loc_ ^= (ni % 2 == 0 ? 1 : 3) * 32; + } else if (WARPS_4x1x1 && N == 8) { // D=128, 1 warp in N + smem_read_loc_ ^= ((ni % 4 == 3) ? 7 : (ni % 2 == 1 ? 3 : 1)) * 16; + } else if (N != 1) { + assert(false); + } + } + +#pragma unroll + for (int ni = n_begin; ni < n_begin + UNROLL_N; ni++) { + int const nii = ni - n_begin; +#pragma unroll + for (int ki = 0; ki < K; ki++) { + fmha::swizzle_rows(regs_[nii][ki].x, regs_[nii][ki].z, tmp[nii][ki].x, + tmp[nii][ki].y); // PRMT 0+1 + fmha::swizzle_rows(regs_[nii][ki].y, regs_[nii][ki].w, tmp[nii][ki].z, + tmp[nii][ki].w); // PRMT 2+3 + } + } + } + + template + inline __device__ void transpose_stmatrix(int n_begin, uint32_t smem_dst) { + // After LDSM.Tx4 registers hold 2x2 elts: + // [00, 01] + // [10, 11] + // With row offsets + // x: + 0 + // y: + 8 + // z: +16 (g) + // w: +24 (o) + // + // After PRMT 0, the : + // [00, 01] [80, 81] => x: [00, 10, 80, 90], i.e. col 0 + // [10, 11] [90, 91] => z: [01, 11, 81, 91], i.e. col 1 + // + // [g0, g1] [o0, o1] => y: [g0, h0, o0, p0], i.e. col 0 + // [h0, h1] [p0, p1] => w: [g1, h1, o1, p1], i.e. 
col 1 + // + // Therefore, when looking at the transpose, quad q holds cols 2 * q + [0, 1], i.e. + // - quad 0 holds cols 0, 1 + // - quad 1 holds cols 2, 3 + // - etc. + // + // This fits with the accumulator layout, since N strides in steps of 8 per thread. + + if (SYNC) { // needed if src and dst are the same. + __syncthreads(); // LDSM.T done. We should now have a D x S tile in registers. SMEM can be + // written. + } + + if (n_begin == 0) { + smem_write_loc_ = smem_dst + write_offset_; + } + +#pragma unroll + for (int ni = n_begin; ni < n_begin + UNROLL_N; ni++) { + int const nii = ni - n_begin; +#pragma unroll + for (int ki = 0; ki < K; ki++) { + fmha::stsm(smem_write_loc_ + ki * BYTES_PER_ROW * D, regs_[nii][ki]); + } + if (WARPS_4x1x1) { // D=64, 1 warp in N. + smem_write_loc_ += 16 * BYTES_PER_ROW; + } else if (WARPS_4x1x2) { // D=64, 2 warps in N. + smem_write_loc_ += 32 * BYTES_PER_ROW; + } else { + assert(false); + } + } + } +}; + +template +struct Transposer { + static_assert(Cta_tile::K % 64 == 0); + + enum { + WARPS_M = Cta_tile::WARPS_M, + WARPS_N = Cta_tile::WARPS_N, + WARPS_K = Cta_tile::WARPS_K, + }; + + enum { + WARPS_4x1x1 = (WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1), + WARPS_4x1x2 = (WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 2), + }; + + enum { BYTES_PER_LDS = 16 }; + + // D=64 and 4 warps. + // Per warp we load 32 rows x 16 columns with LDSM.Tx4, 128 rows per CTA. + enum { S = Cta_tile::K >= 128 ? 128 : Cta_tile::K }; // The sequence length. + + enum { D = Cta_tile::N >= 128 ? 128 : Cta_tile::N }; // The head dimension. + + static_assert(S % 64 == 0); + static_assert(WARPS_4x1x1); + static_assert(D % 32 == 0); + + static_assert(S == 64 && D == 128); + + // Two warps in S dim. + enum { ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING = 64 }; // LDSMx4 + + enum { BYTES_PER_ROW = 128 }; + + enum { ROW_PACKING = Div_up::VALUE }; + + enum { + ROWS_PER_LDSM_PER_CTA = ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING / ROW_PACKING + }; // due to row_packing + + // The number of loads in K dimension. + enum { K = S / ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING }; + + // The number of loads in the D dimension. Use two warps in D dim. 
+ enum { N = D / 32 }; + + uint4 regs_[UNROLL_N][K]; + + uint32_t read_offset_; + uint32_t write_offset_; + uint32_t smem_read_loc_; + uint32_t smem_write_loc_; + + inline __device__ Transposer(int tidx) { + int read_row, read_col; + + if (WARPS_4x1x1 && N == 1) { // D=32, 2 warps in N + read_row = (tidx & 0x20) / 4 + (tidx & 0x1c) / 4; + read_col = (tidx & 0x03) * 2; + read_col ^= (read_row & 0x01); + read_col ^= ((tidx & 0x40) / 64); + } else if (WARPS_4x1x1 && N == 2) { // D=64, 2 warps in N + read_row = (tidx & 0x20) / 2 + (tidx & 0x1e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + read_col ^= ((tidx & 0x40) / 64); + } else if (WARPS_4x1x1 && N == 4) { // D=128, 2 warps in N + read_row = (tidx & 0x3f); + read_col = (tidx & 0x07); + read_col ^= ((tidx & 0x40) / 64); + } else { + assert(false); + } + + read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + + // static_assert(ROWS_PER_LDSM_PER_CTA == 32); + // constexpr int ROWS_PER_XOR_PATTERN = 4; + // constexpr int ROWS_PER_XOR_PATTERN = fmha::Rows_per_xor_pattern_ampere_b::VALUE; + + int row, col; + if (WARPS_4x1x1) { + row = (tidx & 0x40) / 4 + (tidx & 0x10) / 2 + (tidx & 0x07); + col = (tidx & 0x20) / 16 + (tidx & 0x08) / 8; + col = col + (row % 2) * 4; + row = row / 2; + col = col ^ (row % 4); + } else { + assert(false); + } + write_offset_ = row * BYTES_PER_ROW + col * BYTES_PER_LDS; + }; + + inline __device__ void transpose(int tidx, uint32_t smem) { transpose_(tidx, smem, smem); } + + template + inline __device__ void transpose_(uint32_t smem_src, uint32_t smem_dst) { +#pragma unroll + for (int n_begin = 0; n_begin < N; n_begin += UNROLL_N) { + transpose_ldmatrix(n_begin, smem_src); + transpose_stmatrix(n_begin, smem_dst); + } + } + + inline __device__ void transpose_ldmatrix(int n_begin, uint32_t smem_src) { + static_assert(N % UNROLL_N == 0, ""); + + uint4 tmp[UNROLL_N][K]; + if (n_begin == 0) { + smem_read_loc_ = smem_src + read_offset_; + } +#pragma unroll + for (int ni = n_begin; ni < n_begin + UNROLL_N; ni++) { +#pragma unroll + for (int ki = 0; ki < K; ki++) { + int const nii = ni - n_begin; + fmha::ldsmt(tmp[ni][ki], smem_read_loc_ + ki * ROWS_PER_LDSM_PER_CTA * BYTES_PER_ROW); + } + + if (WARPS_4x1x1 && N == 2) { // D=64, 2 warps in N + smem_read_loc_ ^= 32; + } else if (WARPS_4x1x1 && N == 4) { // D=128, 2 warps in N + smem_read_loc_ ^= (ni % 2 == 1 ? 3 * 32 : 32); + } else if (N != 1) { + assert(false); + } + } + +#pragma unroll + for (int ni = n_begin; ni < n_begin + UNROLL_N; ni++) { + int const nii = ni - n_begin; +#pragma unroll + for (int ki = 0; ki < K; ki++) { + fmha::swizzle_rows(regs_[nii][ki].x, regs_[nii][ki].z, tmp[nii][ki].x, + tmp[nii][ki].y); // PRMT 0+1 + fmha::swizzle_rows(regs_[nii][ki].y, regs_[nii][ki].w, tmp[nii][ki].z, + tmp[nii][ki].w); // PRMT 2+3 + } + } + } + + template + inline __device__ void transpose_stmatrix(int n_begin, uint32_t smem_dst) { + // After LDSM.Tx4 registers hold 2x2 elts: + // [00, 01] + // [10, 11] + // With row offsets + // x: + 0 + // y: + 8 + // z: +16 (g) + // w: +24 (o) + // + // After PRMT 0, the : + // [00, 01] [80, 81] => x: [00, 10, 80, 90], i.e. col 0 + // [10, 11] [90, 91] => z: [01, 11, 81, 91], i.e. col 1 + // + // [g0, g1] [o0, o1] => y: [g0, h0, o0, p0], i.e. col 0 + // [h0, h1] [p0, p1] => w: [g1, h1, o1, p1], i.e. col 1 + // + // Therefore, when looking at the transpose, quad q holds cols 2 * q + [0, 1], i.e. + // - quad 0 holds cols 0, 1 + // - quad 1 holds cols 2, 3 + // - etc. 
+ // + // This fits with the accumulator layout, since N strides in steps of 8 per thread. + + if (SYNC) { + __syncthreads(); // LDSM.T done. We should now have a D x S tile in registers. SMEM can be + // written. + } + + if (n_begin == 0) { + smem_write_loc_ = smem_dst + write_offset_; + } + +#pragma unroll + for (int ni = n_begin; ni < n_begin + UNROLL_N; ni++) { + int const nii = ni - n_begin; +#pragma unroll + for (int ki = 0; ki < K; ki++) { + fmha::stsm(smem_write_loc_ + ki * BYTES_PER_ROW * D / 2, regs_[nii][ki]); + } + if (WARPS_4x1x1) { // D=64, 1 warp in N. + smem_write_loc_ += 16 * BYTES_PER_ROW; + } else { + assert(false); + } + } + } +}; + +template < + // The instruction traits. + typename Traits, + // The Cta_tile. + typename Cta_tile, + // The number of buffers. + int BUFFERS_PER_TILE, + // GMMA descriptor mode + fmha::Gmma_descriptor_mode desc_mode, + // USe TMA or not, + bool USE_TMA> +struct Smem_tile_v_gmma { + static_assert(sizeof(typename Traits::B_type) == 1); + + // K is the sequence length dimension (128 for GMMA) + enum { K_ = Cta_tile::K % 128 == 0 ? 128 : 64 }; + + static_assert(Cta_tile::K % K_ == 0); + + // static_assert(Cta_tile::N == 128); + // static_assert(K_ == 128); + // static_assert(BUFFERS_PER_TILE == 2); + + using Cta_tile_gmma_ = + typename Traits::template Cta_tile; + + // TODO Swizzle_32B? + static constexpr fmha::Gmma_descriptor_mode GMMA_DESC_MODE_V = + Cta_tile_gmma_::K * sizeof(typename Traits::B_type) >= 128 + ? fmha::Gmma_descriptor_mode::SWIZZLE_128B + : fmha::Gmma_descriptor_mode::SWIZZLE_64B; + + static_assert( + (Cta_tile::K % 128 == 0 && GMMA_DESC_MODE_V == fmha::Gmma_descriptor_mode::SWIZZLE_128B) || + (Cta_tile::K % 64 == 0 && GMMA_DESC_MODE_V == fmha::Gmma_descriptor_mode::SWIZZLE_64B)); + + enum { NUM_KGROUPS = Cta_tile::K / Cta_tile_gmma_::K }; + + static_assert(NUM_KGROUPS * Cta_tile_gmma_::K == Cta_tile::K); + + enum { BYTES_PER_STS = 16 }; + + // The compute tile only requires static information from Smem_tile_v and accesses SMEM directly + // through GMMA. Hence, we declare a SxD column major matrix in SMEM and have to make sure at + // runtime that the data is transposed. Note that for K > 128, we are using two buffers per tile, + // which we have to fill accordingly. + using Base_ = fmha::Smem_tile_hopper_b; + + // Split D or not, which influences the GMMA_GROUP_SMEM_DISTANCE, and BYTES_PER_BUFFER_NO_4LSB. + // Split-d smem view (2 split D, and 3 buffers): d0, d0, d0, d1, d1, d1. + // The group distance would be number_of_buffers * buffer_size. + // The buffer size is the size for split-d. + static constexpr size_t GMMA_GROUP_SMEM_DISTANCE = + Base_::GMMA_GROUP_SMEM_DISTANCE * BUFFERS_PER_TILE; + static constexpr size_t BYTES_PER_BUFFER_NO_4LSB = Base_::BYTES_PER_BUFFER_NO_4LSB; + + using Gmma_descriptor = typename Base_::Gmma_descriptor; + + struct Base : public Base_ { + using Transposer = Transposer; + static_assert(USE_TMA == false); + static constexpr bool TRANSPOSE = true; + + enum { NUM_KGROUPS = Cta_tile::K / Cta_tile_gmma_::K }; + + enum { ROWS_PER_XOR_PATTERN = fmha::Rows_per_xor_pattern_ampere_b::VALUE }; + + using Descriptor = typename Base_::Gmma_descriptor; + + // Delegate all the stores to the Row-Major Smem_tile. 
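+    // Stores land in the natural row-major S x D layout via the delegate; transpose_tile() below
+    // then uses the Transposer (LDSM.T + PRMT + STSM) to rewrite the same SMEM region into the
+    // transposed layout that the GMMA descriptor expects, so the store path itself stays
+    // oblivious to the transpose.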
+ using Store_delegate = Smem_tile_without_skews; + + using Store_type = typename Store_delegate::Store_type; + + enum { S = Cta_tile::K }; + + // static_assert(Descriptor::BYTES_PER_LEADING_DIM == 128); + // static_assert(Descriptor::STRIDE_BYTE_OFFSET == K_ * 8 / 16); // 128 * 8 / 16 + // static_assert(Descriptor::TRANS_MODE == fmha::Gmma_descriptor_transpose::NOTRANS); + // static_assert(Base::BYTES_PER_TILE == S * 64); + // static_assert(!Descriptor::LEADING_BYTE_OFFSET_NEEDED); + // static_assert(Descriptor::LEADING_BYTE_OFFSET == 128 * 64 / 16); + // static_assert(Descriptor::BYTES_PER_DESC_NO_4LSB == 32 * 1 / 16); + // static_assert(Descriptor::BYTES_DESC_INC_BOUNDARY_NO_4LSB == (K_ / 32 - 1) * 2); + // static_assert(Base::BYTES_PER_BUFFER_NO_4LSB == K_ * 64 / 16); + // static_assert(Base::GMMA_GROUP_SMEM_DISTANCE == 128 * 128 * 2); + // static_assert(Base::BYTES_PER_BUFFER_NO_4LSB == 128 * 128); + + // static_assert(Store_delegate::N_WITH_PADDING == 64); + // static_assert(Store_delegate::ROWS_PER_XOR_PATTERN == 4); + // static_assert(Store_delegate::BYTES_PER_ROW_BEFORE_PACKING == 64); + // static_assert(Store_delegate::ROWS == S / 2); + // static_assert(Store_delegate::BYTES_PER_ROW == 128); + + // Number of rows a warp loads per LDSMx4 + enum { ROWS_PER_LDSM = 4 * 8 }; + + enum { ROWS_PER_LDSM_PER_CTA = ROWS_PER_LDSM * Cta_tile::WARPS_M }; + + static_assert(Cta_tile::WARPS_M == 4); + + enum { LDSMS = Cta_tile::K / ROWS_PER_LDSM_PER_CTA }; + + // TODO we're assigning all rows loaded by a warp group (128 per CTA) to the K dimension. + // This only works for K a multiple of 128. + // For S=192, we want 3 blocks of 64xD. + // static_assert(LDSMS * ROWS_PER_LDSM_PER_CTA == Cta_tile::K); + + static_assert(LDSMS == S / 128); + + enum { BYTES_PER_LDS = 16 }; + + enum { BYTES_PER_ROW = Store_delegate::BYTES_PER_ROW }; + + enum { + WARPS_M = Cta_tile::WARPS_M, + WARPS_N = Cta_tile::WARPS_N, + WARPS_K = Cta_tile::WARPS_K, + }; + + enum { + WARPS_4x1x1 = (WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1), + WARPS_4x1x2 = (WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 2), + }; + + inline __device__ Base(void* smem, int tidx) + : Base_(smem, tidx), delegate(smem, tidx), transposer(tidx) {} + + // Store to the tile in shared memory. + template + inline __device__ void store(Store_type const (&data)[N]) { + uint32_t smem_ptrs[N]; + delegate.compute_store_pointers(smem_ptrs); + sts(smem_ptrs, data); + } + + // Store to the tile in shared memory. + template + inline __device__ void store(Store_type const (&data)[N], uint32_t (&preds)[M]) { + uint32_t smem_ptrs[N]; + delegate.compute_store_pointers(smem_ptrs); + sts(smem_ptrs, data, preds); + } + + // Store to the tile in shared memory. + template + inline __device__ void store(Store_type const (&data)[N], uint32_t preds) { + delegate.store(data, preds); + } + + // Store to the tile in shared memory. + template + inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t (&preds)[M]) { + uint32_t smem_ptrs[N]; + delegate.compute_store_pointers(smem_ptrs); + ldgsts(smem_ptrs, gmem_ptrs, preds); + } + + // Store to the tile in shared memory. + template + inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) { + uint32_t tmp[1] = {preds}; + delegate.store(gmem_ptrs, tmp); + } + + // Store to the tile in shared memory. 
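+    // (Like the overloads above, this variant simply forwards to the row-major store delegate.)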
+ template + inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t preds) { + uint32_t tmp[1] = {preds}; + delegate.store(gmem_ptrs, tmp); + } + + // Initial offset (via tidx) has been moved to ctor + inline __device__ void transpose_tile(int /* tidx */) { transposer.transpose(0, this->smem_); } + + template + inline __device__ void transpose_tile(uint32_t smem_src, uint32_t smem_dst) { + transposer.template transpose_(smem_src, smem_dst); + } + + inline __device__ void transpose_tile_ldmatrix(int, uint32_t smem) { + transposer.transpose_ldmatrix(0, smem); + } + + inline __device__ void transpose_tile_stmatrix(int, uint32_t smem) { + transposer.template transpose_stmatrix(0, smem); + } + + inline __device__ void transpose_tile_128(int tidx) { + // D=64 and 4 warps. + // Per warp we load 32 rows x 16 columns with LDSM.Tx4, 128 rows per CTA. + constexpr int S = Cta_tile::K; // The sequence length. + constexpr int D = Cta_tile::N; // The head dimension. + // static_assert(S == 256); + static_assert(D == 64); + // static_assert(S % 128 == 0); + static_assert(WARPS_4x1x1 || WARPS_4x1x2); + static_assert(D % (16 * WARPS_K) == 0); + + constexpr int ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING = 128; // LDSMx4 + constexpr int BYTES_PER_ROW = 128; + constexpr int ROW_PACKING = BYTES_PER_ROW / (D * sizeof(Traits::B_type)); + + // The number of loads in K dimension. + constexpr int K = S / ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING; + // static_assert(K * ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING == S); + // static_assert(K == 3); + // The number of loads in the D dimension. + constexpr int N = D / (16 * WARPS_K); + static_assert(N * 16 * WARPS_K == D); + + int read_row, read_col; + + if (WARPS_4x1x1 && N == 4) { // D=64, 1 warp in N + read_row = (tidx & 0xe0) / 2 + (tidx & 0x1e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + } else if (WARPS_4x1x2 && N == 2) { // D=64, 2 warps in N + read_row = (tidx & 0x60) / 2 + (tidx & 0x1e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + // For two warpgroups we do two steps in N at once. + read_col ^= (tidx & 0x80) / 128; + } else { + assert(false); + } + + uint32_t offset = read_row * BYTES_PER_ROW + read_col * 16; + + constexpr int ROWS_PER_LDSM_PER_CTA = + ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING / ROW_PACKING; // due to row_packing + + uint4 tmp[N][K]; + uint32_t smem_tmp = this->smem_; //__nvvm_get_smem_pointer(v_smem_) ; + uint32_t smem_loc = smem_tmp + offset; + +#pragma unroll + for (int ni = 0; ni < N; ni++) { +#pragma unroll + for (int ki = 0; ki < K; ki++) { + fmha::ldsmt(tmp[ni][ki], smem_loc + ki * ROWS_PER_LDSM_PER_CTA * BYTES_PER_ROW); + } + + if (WARPS_4x1x1 && N == 4) { // D=64, 1 warp in N + smem_loc ^= (ni % 2 == 0 ? 1 : 3) * 16; + } else if (WARPS_4x1x2 && N == 2) { // D=64, 2 warps in N + smem_loc ^= 32; + } else { + assert(false); + } + } + + uint4 regs[N][K]; + +#pragma unroll + for (int ni = 0; ni < N; ni++) { +#pragma unroll + for (int ki = 0; ki < K; ki++) { + fmha::swizzle_rows(regs[ni][ki].x, regs[ni][ki].z, tmp[ni][ki].x, + tmp[ni][ki].y); // PRMT 0+1 + fmha::swizzle_rows(regs[ni][ki].y, regs[ni][ki].w, tmp[ni][ki].z, + tmp[ni][ki].w); // PRMT 2+3 + } + } + + // After LDSM.Tx4 registers hold 2x2 elts: + // [00, 01] + // [10, 11] + // With row offsets + // x: + 0 + // y: + 8 + // z: +16 (g) + // w: +24 (o) + // + // After PRMT 0, the : + // [00, 01] [80, 81] => x: [00, 10, 80, 90], i.e. col 0 + // [10, 11] [90, 91] => z: [01, 11, 81, 91], i.e. col 1 + // + // [g0, g1] [o0, o1] => y: [g0, h0, o0, p0], i.e. 
col 0 + // [h0, h1] [p0, p1] => w: [g1, h1, o1, p1], i.e. col 1 + // + // Therefore, when looking at the transpose, quad q holds cols 2 * q + [0, 1], i.e. + // - quad 0 holds cols 0, 1 + // - quad 1 holds cols 2, 3 + // - etc. + // + // This fits with the accumulator layout, since N strides in steps of 8 per thread. + + __syncthreads(); // LDSM.T done. We should now have a D x S tile in registers. SMEM can be + // written. + constexpr int ROWS_PER_XOR_PATTERN = fmha::Rows_per_xor_pattern_ampere_b::VALUE; + static_assert(ROWS_PER_XOR_PATTERN == 8); + + int row, col; + if (WARPS_4x1x1) { + row = (tidx & 0x10) / 2 + (tidx & 0x07); + col = (tidx & 0x60) / 16 + (tidx & 0x08) / 8; + } else if (WARPS_4x1x2) { + // Same as above, with second warp group writing next 16 rows. + row = (tidx & 0x80) / 8 + (tidx & 0x10) / 2 + (tidx & 0x07); + col = (tidx & 0x60) / 16 + (tidx & 0x08) / 8; + } else { + assert(false); + } + col ^= (row & 0x07); + int dst = smem_tmp + row * BYTES_PER_ROW + col * BYTES_PER_LDS; + +#pragma unroll + for (int ni = 0; ni < N; ni++) { +#pragma unroll + for (int ki = 0; ki < K; ki++) { + fmha::stsm(dst + ki * BYTES_PER_ROW * D, regs[ni][ki]); + } + if (WARPS_4x1x1 && N == 4) { // D=64, 1 warp in N. + dst += 16 * BYTES_PER_ROW; + } else if (WARPS_4x1x2 && N == 2) { // D=64, 2 warps in N. + dst += 32 * BYTES_PER_ROW; + } else { + assert(false); + } + } + } + + Store_delegate delegate; + Transposer transposer; + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v< + fmha::Hopper_qgmma_e4m3_fp32_traits, Cta_tile, + BUFFERS_PER_TILE, desc_mode, USE_TMA> + : public Smem_tile_v_gmma< + fmha::Hopper_qgmma_e4m3_fp32_traits, + Cta_tile, BUFFERS_PER_TILE, desc_mode, USE_TMA>::Base { + using Traits = fmha::Hopper_qgmma_e4m3_fp32_traits; + + using Base = + typename fmha::Smem_tile_v_gmma::Base; + + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v< + fmha::Hopper_igmma_int8_int32_traits, Cta_tile, + BUFFERS_PER_TILE, desc_mode, USE_TMA> + : public Smem_tile_v_gmma< + fmha::Hopper_igmma_int8_int32_traits, + Cta_tile, BUFFERS_PER_TILE, desc_mode, USE_TMA>::Base { + using Traits = fmha::Hopper_igmma_int8_int32_traits; + + using Base = + typename fmha::Smem_tile_v_gmma::Base; + + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/smem_tile_o.h b/csrc/fmha_v2/fmha/hopper/smem_tile_o.h new file mode 100644 index 0000000000..cd499a5f39 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/smem_tile_o.h @@ -0,0 +1,325 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#pragma once + +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Smem_tile_o_dummy { + enum { BYTES_PER_TILE = 0 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o_gmma_32bit_8bit : public Smem_tile_o_base_8bit_mma { + // The base class. + using Base = Smem_tile_o_base_8bit_mma; + + using Mma_tile = typename Base::Mma_tile; + using Accumulator = typename Base::Accumulator; + + enum { + BYTES_PER_ROW = Base::BYTES_PER_ROW, + BYTES_PER_ROW_WITH_PACKING = Base::BYTES_PER_ROW_WITH_PACKING, + LOOPS = Base::LOOPS, + LDS_PER_LOOP = Base::LDS_PER_LOOP, + ROWS_PER_LDS = Base::ROWS_PER_LDS, + HAS_INCOMPLETE_LDS = Base::HAS_INCOMPLETE_LDS, + }; + + // Ctor. + inline __device__ Smem_tile_o_gmma_32bit_8bit(void* smem, int tidx) : Base(smem, tidx) {} + + // Store the accumulators. + inline __device__ void store(Accumulator const (&acc)[1][1], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; + + static_assert(M_PER_MMA == 64); + static_assert(Base::WARPS_4x1x2); + + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + + static_assert(MMAS_M_PER_LOOP == 1); + static_assert(Mma_tile::MMAS_N == 1); + static_assert(Mma_tile::CORES_N == 8); + static_assert(Accumulator::NUM_REGS == Mma_tile::CORES_N / 2 * 8); + static_assert(BYTES_PER_ROW == 64 * 4); + static_assert(Cta_tile::WARPS_K == 2); + + static_assert(Mma_tile::CORES_N / 2 == 4); + +#pragma unroll + for (int ni = 0; ni < Mma_tile::CORES_N / 2; ++ni) { +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + uint4 row_0; + row_0.x = acc[0][0].reg(ni * 8 + 0); // Even + row_0.y = acc[0][0].reg(ni * 8 + 4); // Odd + row_0.z = acc[0][0].reg(ni * 8 + 1); // Even + row_0.w = acc[0][0].reg(ni * 8 + 5); // Odd + uint4 row_1; + row_1.x = acc[0][0].reg(ni * 8 + 2); // Even + row_1.y = acc[0][0].reg(ni * 8 + 6); // Odd + row_1.z = acc[0][0].reg(ni * 8 + 3); // Even + row_1.w = acc[0][0].reg(ni * 8 + 7); // Odd + + // Regs_to_rows::extract(acc[mi * MMAS_M_PER_LOOP + mj][ni], row_0, row_1); + + // Each thread of a quad writes 16B per STS -> 64B per store. Account for 2 -> 128B. + int imm_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW * Cta_tile::WARPS_K + (ni / 2) * 128; + int imm_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW * Cta_tile::WARPS_K + (ni / 2) * 128; + + // Store the elements. + fmha::sts(this->smem_write_ + imm_0, row_0); + fmha::sts(this->smem_write_ + imm_1, row_1); + } + // Each thread of a quad writes 16B per STS -> 64B per store. + if (Mma_tile::MMAS_N == 1) { + this->smem_write_ ^= 64; + } else { + assert(false && "Unsupported"); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o, Cta_tile> + : public Hmma_smem_tile_o< + Hopper_hgmma_fp16_traits, Cta_tile> { + // The traits class. + using Traits = Hopper_hgmma_fp16_traits; + // The base class. + using Base = Hmma_smem_tile_o; + + using Mma_tile = typename Base::Mma_tile; + + using Accumulator = typename Base::Accumulator; + + enum { + LOOPS = Base::LOOPS, + ROW_PACKING = Base::ROW_PACKING, + BYTES_PER_ROW = Base::BYTES_PER_ROW, + }; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} + + // Store the accumulators. 
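+  // Note on the XOR pattern below: recomputing (smem_write_ ^ (ni * 16)) each iteration is
+  // equivalent to in-place XOR steps of 16 * {1, 3, 1, 7, 1, 3, 1, ...}, which is what the
+  // "inplace multiples" remark inside the loop refers to.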
+ inline __device__ void store(Accumulator const (&acc)[1][1], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::CORES_N; ++ni) { + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + + static_assert(MMAS_M_PER_LOOP == 1); + // inplace multiples seem to be 1, 3, 1, 7, 1, 3, 1, + auto smem_write = this->smem_write_ ^ (ni * 16); +// Store 1st column of the different MMAs. +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW * Cta_tile::WARPS_K; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW * Cta_tile::WARPS_K; + + // Store. + fmha::sts(smem_write + row_0, acc[0][0].reg(ni * 2 + 0)); + fmha::sts(smem_write + row_1, acc[0][0].reg(ni * 2 + 1)); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o, Cta_tile> + : public Hmma_smem_tile_o< + Hopper_hgmma_fp32_traits, Cta_tile> { + // The traits class. + using Traits = Hopper_hgmma_fp32_traits; + // The base class. + using Base = Hmma_smem_tile_o; + + using Mma_tile = typename Base::Mma_tile; + + using Accumulator = typename Base::Accumulator; + + enum { + LOOPS = Base::LOOPS, + ROW_PACKING = Base::ROW_PACKING, + BYTES_PER_ROW = Base::BYTES_PER_ROW, + }; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} + + // Store the accumulators. + inline __device__ void store(Accumulator const (&acc)[1][1], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::CORES_N; ++ni) { + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + + static_assert(MMAS_M_PER_LOOP == 1); + // inplace multiples seem to be 1, 3, 1, 7, 1, 3, 1, + auto smem_write = this->smem_write_ ^ (ni * 16); +// Store 1st column of the different MMAs. +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW * Cta_tile::WARPS_K; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW * Cta_tile::WARPS_K; + + uint32_t val_0 = float2_to_half2(acc[0][0].elt(2 * ni * Mma_tile::CORES_M + 0), + acc[0][0].elt(2 * ni * Mma_tile::CORES_M + 1)); + + uint32_t val_1 = float2_to_half2(acc[0][0].elt(2 * ni * Mma_tile::CORES_M + 2), + acc[0][0].elt(2 * ni * Mma_tile::CORES_M + 3)); + + // Store. + fmha::sts(smem_write + row_0, val_0); + fmha::sts(smem_write + row_1, val_1); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o, Cta_tile> + : public Hmma_smem_tile_o< + Hopper_hgmma_bf16_traits, Cta_tile> { + // The traits class. + using Traits = Hopper_hgmma_bf16_traits; + // The base class. + using Base = Hmma_smem_tile_o; + + using Mma_tile = typename Base::Mma_tile; + + using Accumulator = typename Base::Accumulator; + + enum { + LOOPS = Base::LOOPS, + ROW_PACKING = Base::ROW_PACKING, + BYTES_PER_ROW = Base::BYTES_PER_ROW, + }; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} + + // Convert fp32 to bf16, and store the accumulators. 
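+  // Mirrors the fp32 -> fp16 path above: each pair of consecutive fp32 accumulator elements is
+  // packed into one 32-bit register holding two bf16 values before the 4B STS.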
+ inline __device__ void store(Accumulator const (&acc)[1][1], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; + + static_assert(Mma_tile::CORES_M == 2); + +#pragma unroll + for (int ni = 0; ni < Mma_tile::CORES_N; ++ni) { + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + + static_assert(MMAS_M_PER_LOOP == 1); + // inplace multiples seem to be 1, 3, 1, 7, 1, 3, 1, + auto smem_write = this->smem_write_ ^ (ni * 16); +// Store 1st column of the different MMAs. +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW * Cta_tile::WARPS_K; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW * Cta_tile::WARPS_K; + + uint32_t val_0 = float2_to_bf16_x2(acc[0][0].elt(2 * ni * Mma_tile::CORES_M + 0), + acc[0][0].elt(2 * ni * Mma_tile::CORES_M + 1)); + + uint32_t val_1 = float2_to_bf16_x2(acc[0][0].elt(2 * ni * Mma_tile::CORES_M + 2), + acc[0][0].elt(2 * ni * Mma_tile::CORES_M + 3)); + + // Store. + fmha::sts(smem_write + row_0, val_0); + fmha::sts(smem_write + row_1, val_1); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o, + Cta_tile> + : public Smem_tile_o_gmma_32bit_8bit< + Hopper_qgmma_e4m3_fp32_traits, Cta_tile> { + // The traits class. + using Traits = Hopper_qgmma_e4m3_fp32_traits; + // The base class. + using Base = Smem_tile_o_gmma_32bit_8bit; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} +}; + +template +struct Smem_tile_o, + Cta_tile> + : public Smem_tile_o_gmma_32bit_8bit< + Hopper_igmma_int8_int32_traits, Cta_tile> { + // The traits class. + using Traits = Hopper_igmma_int8_int32_traits; + // The base class. + using Base = Smem_tile_o_gmma_32bit_8bit; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/tma_descriptor.h b/csrc/fmha_v2/fmha/hopper/tma_descriptor.h new file mode 100644 index 0000000000..22071f3585 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/tma_descriptor.h @@ -0,0 +1,348 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include + +namespace fmha { + +// manage TMA descriptor host code. +// allocate, deallocate and manipulate tma desc in the host +// copy the tma descriptor from host code to device code +// Multiple TMA desc, one desc per batch. +// Device desc ptr should be allocated outside the class and reused +template < + // number of dimensions. 
+ int NUM_DIMS> +class Multiple_tma_descriptor { + public: + // ctor + Multiple_tma_descriptor(int batch_size_) : batch_size(batch_size_) { + if (batch_size > 0) { + // allocate host memory + desc_ptr_h = new cudaTmaDesc[batch_size]; + // make sure all bit fields are zeros. + memset(desc_ptr_h, 0, sizeof(cudaTmaDesc) * batch_size); + } + } + + // ctor + Multiple_tma_descriptor() = default; + + // destructor. + ~Multiple_tma_descriptor() { + if (batch_size > 0) { + // deallocate host memory + delete[] desc_ptr_h; + } + } + + // set the desctriptor. + int set_tma_desctriptor( + // ptr to gmem + void const* gmem_ptr, + // format is really data_type in TMA terminology. + cudaTmaDescFormat format, + // interleave mode. + cudaTmaDescInterleave interleave, + // swizzle mode. + cudaTmaDescSwizzle swizzle, + // L2 sector promotion. + cudaTmaDescPromotion promotion, uint32_t const (&tensor_size_array)[NUM_DIMS], + uint64_t const (&tensor_stride_array)[NUM_DIMS - 1], + uint32_t const (&traversal_stride_array)[NUM_DIMS], + uint32_t const (&box_size_array)[NUM_DIMS], + // OOB fill mode. + uint32_t fill_oob, + // FP32 to TF32 conversion. + uint32_t round_to_tf32, + // index to desc. + int batch_idx) { + set_tensor_common_0(&desc_ptr_h[batch_idx], reinterpret_cast(gmem_ptr)); + set_tensor_common_1(&desc_ptr_h[batch_idx], TILED, NUM_DIMS, format, interleave, swizzle, + fill_oob, round_to_tf32, promotion); + + set_tensor_stride(&desc_ptr_h[batch_idx], tensor_stride_array); + set_tensor_size(&desc_ptr_h[batch_idx], tensor_size_array); + + set_traversal_stride_tiled(&desc_ptr_h[batch_idx], traversal_stride_array); + + set_box_size(&desc_ptr_h[batch_idx], box_size_array); + return 0; + } + + // set the desctriptor. + int set_tma_desctriptor( + // ptr to gmem + void const* gmem_ptr, + // format is really data_type in TMA terminology. + cudaTmaDescFormat format, + // interleave mode. + cudaTmaDescInterleave interleave, + // swizzle mode. + cudaTmaDescSwizzle swizzle, + // L2 sector promotion. + cudaTmaDescPromotion promotion, uint32_t const (&tensor_size_array)[NUM_DIMS], + uint64_t const (&tensor_stride_array)[NUM_DIMS - 1], + uint32_t const (&traversal_stride_array)[NUM_DIMS], + uint32_t const (&box_size_array)[NUM_DIMS], + // OOB fill mode. + uint32_t fill_oob, + // FP32 to TF32 conversion. + uint32_t round_to_tf32, + // index to desc. 
+ cudaTmaDesc* desc_ptr = nullptr) { + set_tensor_common_0(desc_ptr, reinterpret_cast(gmem_ptr)); + set_tensor_common_1(desc_ptr, TILED, NUM_DIMS, format, interleave, swizzle, fill_oob, + round_to_tf32, promotion); + + set_tensor_stride(desc_ptr, tensor_stride_array); + set_tensor_size(desc_ptr, tensor_size_array); + + set_traversal_stride_tiled(desc_ptr, traversal_stride_array); + + set_box_size(desc_ptr, box_size_array); + return 0; + } + + // copy the desc to device memory + void copy_to_device(void* desc_ptr_d_, cudaStream_t stream = 0) { + FMHA_CHECK_CUDA(cudaMemcpy(desc_ptr_d_, desc_ptr_h, TMA_DESC_SIZE_IN_BYTE * batch_size, + cudaMemcpyHostToDevice)); + } + + // get desc in host + cudaTmaDesc get_desc_in_host(int batch_idx) const { return desc_ptr_h[batch_idx]; } + + private: + void set_tensor_common_0(cudaTmaDesc* p_desc, uint64_t addr) { + cudaTmaDescTiled* desc = reinterpret_cast(p_desc); + desc->tensor_common0 = 0; + desc->tensor_common0 |= (addr); + } + + void set_tensor_common_1(cudaTmaDesc* p_desc, cudaTmaDescType desc_type, uint32_t dims, + cudaTmaDescFormat format, cudaTmaDescInterleave interleave, + cudaTmaDescSwizzle swizzle, uint32_t fill, uint32_t f32_to_tf32, + cudaTmaDescPromotion promotion) { + cudaTmaDescTiled* desc = reinterpret_cast(p_desc); + + desc->tensor_common1 = 0; + desc->tensor_common1 |= desc_type == TILED ? 0x0 : 0x1; + + constexpr uint32_t VERSION_SHIFT = 1; + constexpr uint32_t VERSION_BITS = 3; + desc->tensor_common1 |= (1u << VERSION_SHIFT); + + constexpr uint32_t DIM_BITS = 3; + constexpr uint32_t DIM_SHIFT = VERSION_SHIFT + VERSION_BITS; + constexpr uint32_t DIM_MASK = (1u << DIM_BITS) - 1; + desc->tensor_common1 |= ((dims - 1) & DIM_MASK) << DIM_SHIFT; + + constexpr uint32_t FORMAT_BITS = 4; + constexpr uint32_t FORMAT_SHIFT = DIM_SHIFT + DIM_BITS; + constexpr uint32_t FORMAT_MASK = (1u << FORMAT_BITS) - 1; + desc->tensor_common1 |= (static_cast(format) & FORMAT_MASK) << FORMAT_SHIFT; + + constexpr uint32_t INTERLEAVE_BITS = 2; + constexpr uint32_t INTERLEAVE_SHIFT = FORMAT_SHIFT + FORMAT_BITS; + constexpr uint32_t INTERLEAVE_MASK = (1u << INTERLEAVE_BITS) - 1; + desc->tensor_common1 |= (static_cast(interleave) & INTERLEAVE_MASK) + << INTERLEAVE_SHIFT; + + constexpr uint32_t SWIZZLE_BITS = 2; + constexpr uint32_t SWIZZLE_SHIFT = INTERLEAVE_SHIFT + INTERLEAVE_BITS; + constexpr uint32_t SWIZZLE_MASK = (1u << SWIZZLE_BITS) - 1; + desc->tensor_common1 |= (static_cast(swizzle) & SWIZZLE_MASK) << SWIZZLE_SHIFT; + + constexpr uint32_t FILL_BITS = 1; + constexpr uint32_t FILL_SHIFT = SWIZZLE_SHIFT + SWIZZLE_BITS; + constexpr uint32_t FILL_MASK = (1u << FILL_BITS) - 1; + desc->tensor_common1 |= (static_cast(fill) & FILL_MASK) << FILL_SHIFT; + + constexpr uint32_t F32_TO_TF32_BITS = 1; + constexpr uint32_t F32_TO_TF32_SHIFT = FILL_SHIFT + FILL_BITS; + constexpr uint32_t F32_TO_TF32_MASK = (1u << F32_TO_TF32_BITS) - 1; + desc->tensor_common1 |= (static_cast(f32_to_tf32) & F32_TO_TF32_MASK) + << F32_TO_TF32_SHIFT; + + constexpr uint32_t PROMOTION_BITS = 2; + constexpr uint32_t PROMOTION_SHIFT = F32_TO_TF32_SHIFT + F32_TO_TF32_BITS; + constexpr uint32_t PROMOTION_MASK = (1u << PROMOTION_BITS) - 1; + desc->tensor_common1 |= (static_cast(promotion) & PROMOTION_MASK) << PROMOTION_SHIFT; + } + + // note that tensor stride has 1 less dim. 
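+  // Strides are encoded in 16-byte units: for each of the NUM_DIMS - 1 outer dimensions,
+  // (stride_bytes >> 4) is split into 32 low bits (tensor_stride_lower[i]) plus 4 high bits
+  // packed into tensor_stride_upper at bit position 4 * i.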
+ void set_tensor_stride(cudaTmaDesc* p_desc, uint64_t const (&tensor_stride_array)[NUM_DIMS - 1]) { + cudaTmaDescTiled* desc = reinterpret_cast(p_desc); + + constexpr uint32_t TENSOR_STRIDE_UPPER_BITS = 4; + constexpr uint32_t TENSOR_STRIDE_UPPER_MASK = (1u << TENSOR_STRIDE_UPPER_BITS) - 1; + + for (uint32_t i = 0; i < NUM_DIMS - 1; i++) { + desc->tensor_stride_lower[i] = 0u; + uint64_t tensor_stride_lower_64b = (tensor_stride_array[i] >> 4) & 0xFFFFFFFFlu; + desc->tensor_stride_lower[i] = static_cast(tensor_stride_lower_64b); + } + desc->tensor_stride_upper = 0u; + + for (uint32_t i = 0; i < NUM_DIMS - 1; i++) { + uint64_t tensor_stride_temp = tensor_stride_array[i]; + tensor_stride_temp = tensor_stride_temp >> 4; + uint64_t tensor_stride_upper = tensor_stride_temp >> 32; + uint32_t tensor_stride_upper_32b = static_cast(tensor_stride_upper); + desc->tensor_stride_upper |= + ((tensor_stride_upper_32b & TENSOR_STRIDE_UPPER_MASK) << (i * TENSOR_STRIDE_UPPER_BITS)); + } + } + + void set_tensor_size(cudaTmaDesc* p_desc, uint32_t const (&tensor_size_array)[NUM_DIMS]) { + cudaTmaDescTiled* desc = reinterpret_cast(p_desc); + for (uint32_t dim = 0; dim < NUM_DIMS; dim++) { + desc->tensor_size[dim] = tensor_size_array[dim] - 1; + } + } + + void set_traversal_stride_tiled(cudaTmaDesc* p_desc, + uint32_t const (&traversal_stride_array)[NUM_DIMS]) { + cudaTmaDescTiled* desc = reinterpret_cast(p_desc); + + desc->traversal_stride_box_0 = 0; + + constexpr uint32_t TRAVERSAL_STRIDE_BITS = 3; + constexpr uint32_t TRAVERSAL_STRIDE_MASK = (1u << TRAVERSAL_STRIDE_BITS) - 1; + + for (uint32_t dim = 0; dim < NUM_DIMS; dim++) { + uint32_t traversal_stride = traversal_stride_array[dim] - 1; + traversal_stride = (traversal_stride & TRAVERSAL_STRIDE_MASK) + << (dim * TRAVERSAL_STRIDE_BITS); + desc->traversal_stride_box_0 |= traversal_stride; + } + } + + void set_box_size(cudaTmaDesc* p_desc, uint32_t const (&box_size_array)[NUM_DIMS]) { + cudaTmaDescTiled* desc = reinterpret_cast(p_desc); + + desc->box_size_end = 0; + + constexpr uint32_t BOX_SIZE_BITS = 8; + constexpr uint32_t BOX_SIZE_MASK = (1 << BOX_SIZE_BITS) - 1; + + if (NUM_DIMS > 1) { + uint32_t box_size_0 = box_size_array[0] - 1; + box_size_0 = box_size_0 & BOX_SIZE_MASK; + box_size_0 = box_size_0 << 24; + desc->traversal_stride_box_0 |= box_size_0; + } + + for (uint32_t dim = 1; dim < NUM_DIMS; dim++) { + uint32_t box_size = box_size_array[dim] - 1; + box_size = box_size & BOX_SIZE_MASK; + box_size = box_size << ((dim - 1) * BOX_SIZE_BITS); + desc->box_size_end |= box_size; + } + } + + void set_traversal_stride_im2col(cudaTmaDesc* p_desc, uint32_t* p_traversal_stride, + uint32_t dims) { + cudaTmaDescIm2Col* desc = reinterpret_cast(p_desc); + + desc->traversal_stride_range_c = 0; + + constexpr uint32_t TRAVERSAL_STRIDE_BITS = 3; + constexpr uint32_t TRAVERSAL_STRIDE_MASK = (1u << (TRAVERSAL_STRIDE_BITS + 1)) - 1; + + for (uint32_t dim = 0; dim < dims; dim++) { + uint32_t traversal_stride = p_traversal_stride[dim] - 1; + traversal_stride = (traversal_stride & TRAVERSAL_STRIDE_MASK) + << (dim * TRAVERSAL_STRIDE_BITS); + desc->traversal_stride_range_c |= traversal_stride; + } + } + + void set_range_c(cudaTmaDesc* p_desc, uint32_t range_c) { + cudaTmaDescIm2Col* desc = reinterpret_cast(p_desc); + + constexpr uint32_t RANGE_C_BITS = 8; + constexpr uint32_t RANGE_C_MASK = (1u << RANGE_C_BITS) - 1; + + range_c = range_c & RANGE_C_MASK; + desc->traversal_stride_range_c |= ((range_c - 1) << 24); + } + + void set_box_corner_dhw(cudaTmaDesc* p_desc, uint32_t* 
p_base_corner, uint32_t* p_far_corner, + uint32_t dims) { + cudaTmaDescIm2Col* desc = reinterpret_cast(p_desc); + + desc->box_corner_dhw = 0; + + uint32_t box_base_corner = 0, box_far_corner = 0; + uint32_t box_corner_dhw = 0; + + if (dims == 3) { + constexpr uint32_t BOX_CORNER_BITS = 16; + constexpr uint32_t BOX_CORNER_MASK = (1u << BOX_CORNER_BITS) - 1; + + box_base_corner = p_base_corner[0] & BOX_CORNER_MASK; + box_far_corner = p_far_corner[0] & BOX_CORNER_MASK; + } + + if (dims == 4) { + constexpr uint32_t BOX_CORNER_BITS = 8; + constexpr uint32_t BOX_CORNER_MASK = (1u << BOX_CORNER_BITS) - 1; + + box_base_corner = p_base_corner[0] & BOX_CORNER_MASK; + box_base_corner |= ((p_base_corner[1] & BOX_CORNER_MASK) << BOX_CORNER_BITS); + + box_far_corner = p_far_corner[0] & BOX_CORNER_MASK; + box_far_corner |= ((p_far_corner[1] & BOX_CORNER_MASK) << BOX_CORNER_BITS); + } + + if (dims == 5) { + constexpr uint32_t BOX_CORNER_BITS = 5; + constexpr uint32_t BOX_CORNER_MASK = (1u << BOX_CORNER_BITS) - 1; + + box_base_corner = p_base_corner[0] & BOX_CORNER_MASK; + box_base_corner |= ((p_base_corner[1] & BOX_CORNER_MASK) << BOX_CORNER_BITS); + box_base_corner |= ((p_base_corner[2] & BOX_CORNER_MASK) << (2 * BOX_CORNER_BITS)); + + box_far_corner = p_far_corner[0] & BOX_CORNER_MASK; + box_far_corner |= ((p_far_corner[1] & BOX_CORNER_MASK) << BOX_CORNER_BITS); + box_far_corner |= ((p_far_corner[2] & BOX_CORNER_MASK) << (2 * BOX_CORNER_BITS)); + } + + box_corner_dhw = box_base_corner; + box_corner_dhw |= (box_far_corner << 16); + + desc->box_corner_dhw = box_corner_dhw; + } + + void set_range_ndhw(cudaTmaDesc* p_desc, uint32_t ndhw) { + cudaTmaDescIm2Col* desc = reinterpret_cast(p_desc); + + desc->range_ndhw = 0; + + constexpr uint32_t RANGE_NDHW_BITS = 10; + constexpr uint32_t RANGE_NDHW_MASK = (1u << RANGE_NDHW_BITS) - 1; + + desc->range_ndhw = ((ndhw - 1) & RANGE_NDHW_MASK); + } + + // The TMA descriptor. Each is of 512 bit. + cudaTmaDesc* desc_ptr_h; + // The TMA descriptor on the device memory. + cudaTmaDesc* desc_ptr_d; + // Number of batches + int batch_size = 0; +}; + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/tma_types.h b/csrc/fmha_v2/fmha/hopper/tma_types.h new file mode 100644 index 0000000000..4f5460ef64 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/tma_types.h @@ -0,0 +1,123 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include + +namespace fmha { + +// TMA desc type. +typedef enum { TILED = 0, IM2COL } cudaTmaDescType; + +// TMA swizzle type. +typedef enum { + SWIZZLE_DISABLED, + SWIZZLE_32B, + SWIZZLE_64B, + SWIZZLE_128B, + SWIZZLE_MAX +} cudaTmaDescSwizzle; + +typedef enum { BARRIER64, BARRIER128 } cudaTmaDescBarrier; + +// TMA interleave type. +typedef enum { + INTERLEAVE_DISABLED, + INTERLEAVE_16B, + INTERLEAVE_32B, + INTERLEAVE_MAX +} cudaTmaDescInterleave; + +// TMA L2 sector promotion. 
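+// (The enum values below are encoded directly into the descriptor bit fields packed by
+// Multiple_tma_descriptor::set_tensor_common_1 in tma_descriptor.h.)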
+typedef enum { + PROMOTION_DISABLED = 0, + PROMOTION_64B, + PROMOTION_128B, + PROMOTION_256B +} cudaTmaDescPromotion; + +// TMA data type. +typedef enum { + U8 = 0, + U16, + U32, + S32, + U64, + S64, + F16_RN, + F32_RN, + F32_FTZ_RN, + F64_RN, + BF16_RN, + FORMAT_MAX +} cudaTmaDescFormat; + +// TMA cache control. +typedef enum { + PREFETCH, // Prefetch tma descriptor using global memory address + INVALIDATE, // Invalidate tma descriptor in l2 cache + INVALIDATE_ALL // Invalidate tma descriptor and all elements in l2 cache line +} cudaTmaDescCacheCtrl; + +// TMA OOB fill modes. +typedef enum { TENSOR_ZFILL, TENSOR_CFILL } cudaTmaDescOobFillMode; + +constexpr uint64_t k_max_tensor_size = (1llu << 36); +constexpr uint64_t k_max_tensor_stride = (1llu << 36); +constexpr uint64_t k_max_block_size = 256llu; +constexpr uint64_t k_max_traversal_stride = (1llu << 3); + +constexpr uint64_t k_min_tensor_size = 1llu; +constexpr uint64_t k_min_tensor_stride = 0llu; +constexpr uint64_t k_min_block_size = 1llu; +constexpr uint64_t k_min_traversal_stride = 1llu; + +constexpr uint32_t k_max_cta_id = (1 << 6) - 1; + +// The 512 bit of descriptor for tiled mode. +typedef struct { + uint64_t tensor_common0; + uint32_t tensor_common1; + + uint32_t tensor_stride_lower[4]; //< 36b of 64b with 4B aligned + uint32_t tensor_stride_upper; + uint32_t tensor_size[5]; //< value -1 + uint32_t traversal_stride_box_0; //< packed 3b (-1) + + uint32_t box_size_end; +} cudaTmaDescTiled; + +// The 512 bit of descritptro for im2col mode. +typedef struct { + uint64_t tensor_common0; + uint32_t tensor_common1; + + uint32_t tensor_stride_lower[4]; + uint32_t tensor_stride_upper; + uint32_t tensor_size[5]; + uint32_t traversal_stride_range_c; + + uint32_t box_corner_dhw; + uint32_t range_ndhw; +} cudaTmaDescIm2Col; + +// TMA desc size +constexpr uint32_t TMA_DESC_SIZE_IN_BYTE = 64; + +// TMA desc +typedef struct alignas(64) { + uint64_t data[8]; +} cudaTmaDesc; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/utils_gmma.h b/csrc/fmha_v2/fmha/hopper/utils_gmma.h new file mode 100644 index 0000000000..cc070be7de --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/utils_gmma.h @@ -0,0 +1,18 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include +#include +#include diff --git a/csrc/fmha_v2/fmha/hopper/utils_hgmma.h b/csrc/fmha_v2/fmha/hopper/utils_hgmma.h new file mode 100644 index 0000000000..5112317228 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/utils_hgmma.h @@ -0,0 +1,874 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. 
Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// GMMAs with fp16 Accumulator +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp16 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp16<8, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[2]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16\n" + "{\n" + " %0, %1\n" + "}, %2, %3, 1, 1, 1, %4, %5;\n" + + : "+r"(acc[0]), "+r"(acc[1]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp16<32, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[8]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7 \n" + "},\n" + " %8, %9, 1, 1, 1, %10, %11;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp16<64, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15 \n" + "},\n" + " %16, %17, 1, 1, 1, %18, %19;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp16<128, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + " %32, %33, 1, 1, 1, %34, %35;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp16<192, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[48]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47 \n" + "},\n" + " %48, %49, 1, 1, 1, %50, %51;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp16<256, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + " %64, %65, 1, 1, 1, %66, %67;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void hgmma_fp16(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[N / 4]) { + Hgmma_fp16::mma(desc_a, desc_b, acc); +} + 
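+// Minimal usage sketch of the dispatch helper above (illustrative only; the template-argument
+// order <N, TA, TB> is assumed to mirror the Hgmma_fp16 specializations, and building the SMEM
+// matrix descriptors plus the surrounding wgmma fence/commit/wait sequencing is assumed to happen
+// elsewhere in the kernel):
+//
+//   uint64_t desc_a = ..., desc_b = ...;  // GMMA descriptors of the A and B SMEM tiles
+//   uint32_t acc[64 / 4] = {};            // N = 64 -> 16 registers, each packing two fp16 values
+//   fmha::hgmma_fp16<64, false, false>(desc_a, desc_b, acc);
+//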
+//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// GMMAs with fp32 Accumulator +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp32<8, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3}, %4, %5, 1, 1, 1, %6, %7;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp32<64, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + " %32, %33, 1, 1, 1, %34, %35;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp32<128, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + " %64, %65, 1, 1, 1, %66, %67;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp32<192, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + " %96, %97, 1, 1, 1, %98, %99;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp32<256, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + " %128, %129, 1, 1, 1, %130, %131;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void hgmma_fp32(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[N / 2]) { + Hgmma_fp32::mma(desc_a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// GMMAs with fp16 Accumulator, where A is coming from RF +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp16 {}; + 
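+// As the banner above says, the A operand for these "rfa" variants comes from the register file:
+// each thread supplies a uint32_t[4] fragment for A instead of an SMEM matrix descriptor, only B
+// is still addressed through a GMMA descriptor, and consequently only the trans_b immediate
+// remains.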
+//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp16<8, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[2]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " + "{%0, %1}, {%2, %3, %4, %5}, %6, 1, 1, 1, %7;\n" + + : "+r"(acc[0]), "+r"(acc[1]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_a), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x16x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp16<16, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " + "{ %0, %1, %2, %3 },\n" + "{ %4, %5, %6, %7 }, %8, 1, 1, 1, %9;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp16<32, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[8]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " + "{ %0, %1, %2, %3, %4, %5, %6, %7 },\n" + "{ %8, %9, %10, %11 }, %12, 1, 1, 1, %13;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp16<64, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " + "{" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15 \n" + "},\n" + "{ %16, %17, %18, %19 }, %20, 1, 1, 1, %21;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp16<128, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " + "{" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + "{ %32, %33, %34, %35 }, %36, 1, 1, 1, %37;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp16<192, TB> { + static inline __device__ void mma(const uint32_t (&a)[4], uint64_t desc_b, uint32_t (&acc)[48]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " + "{" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47 \n" + "},\n" + "{ %48, %49, %50, %51 }, %52, 1, 1, 1, %53;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp16<256, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " + "{" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + "{ %64, %65, %66, %67 }, %68, 1, 1, 1, %69;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void hgmma_rfa_fp16(uint32_t const (&a)[4], uint64_t desc_b, + uint32_t (&acc)[N / 4]) { + Hgmma_rfa_fp16::mma(a, desc_b, 
acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// GMMAs with fp32 Accumulator, where A is coming from RF +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp32<8, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3\n" + "}\n," + "{ %4, %5, %6, %7 }, %8, 1, 1, 1, %9;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp32<32, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15 \n" + "},\n" + "{ %16, %17, %18, %19 }, %20, 1, 1, 1, %21;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp32<64, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + "{ %32, %33, %34, %35 }, %36, 1, 1, 1, %37;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp32<128, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + "{ %64, %65, %66, %67 }, %68, 1, 1, 1, %69;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp32<192, TB> { + static inline __device__ void mma(const uint32_t (&a)[4], uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + "{ %96, %97, %98, %99 }, %100, 1, 1, 1, %101;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp32<256, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + "{ %128, %129, %130, %131 }, %132, 1, 1, 1, %133;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void hgmma_rfa_fp32(uint32_t const (&a)[4], uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Hgmma_rfa_fp32::mma(a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h b/csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h new file mode 100644 index 0000000000..7b17b508bb --- 
/dev/null +++ b/csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h @@ -0,0 +1,475 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// BF16 GMMAs with FP32 Accumulator +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_bf16 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_bf16<8, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3}, %4, %5, 1, 1, 1, %6, %7;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_bf16<64, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + " %32, %33, 1, 1, 1, %34, %35;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_bf16<128, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + " %64, %65, 1, 1, 1, %66, %67;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_bf16<192, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + " %96, %97, 1, 1, 1, %98, %99;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_bf16<256, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + " %128, %129, 1, 1, 1, %130, %131;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void hgmma_bf16(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[N / 2]) { + Hgmma_bf16::mma(desc_a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// BF16 GMMAs with FP32 Accumulator, where A is coming from RF +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_bf16 {}; + 
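For the register-sourced-A ("rfa") variants declared below, the A operand is the per-thread fragment of a 64x16 tile of 16-bit values: 1024 elements split across the 128 threads of a warpgroup gives 8 values per thread, i.e. four packed 32-bit registers, which is why every rfa entry point takes uint32_t const (&a)[4]. The sketch below shows only that 8-values-to-4-registers packing arithmetic; the actual element-to-register order is dictated by the wgmma fragment layout and is not reproduced here, and the helper name is an assumption.

// Illustrative sketch only: packing one thread's eight bf16 A-values into the four
// 32-bit registers the rfa wrappers expect. Assumes <cuda_bf16.h> is available for
// __nv_bfloat162 / __floats2bfloat162_rn; the swizzle/order is intentionally omitted.
#include <cuda_bf16.h>

__device__ inline void example_pack_a_fragment_bf16(float const (&vals)[8],
                                                    uint32_t (&a)[4]) {
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    // Two bf16 values per 32-bit register.
    __nv_bfloat162 pair = __floats2bfloat162_rn(vals[2 * i], vals[2 * i + 1]);
    a[i] = reinterpret_cast<uint32_t const&>(pair);
  }
}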
+//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_bf16<8, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3\n" + "}\n," + "{ %4, %5, %6, %7 }, %8, 1, 1, 1, %9;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_bf16<32, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15 \n" + "},\n" + "{ %16, %17, %18, %19 }, %20, 1, 1, 1, %21;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_bf16<64, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + "{ %32, %33, %34, %35 }, %36, 1, 1, 1, %37;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_bf16<128, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + "{ %64, %65, %66, %67 }, %68, 1, 1, 1, %69;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_bf16<192, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + "{ %96, %97, %98, %99 }, %100, 1, 1, 1, %101;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_bf16<256, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + "{ %128, %129, %130, %131 }, %132, 1, 1, 1, %133;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void hgmma_rfa_bf16(uint32_t const (&a)[4], uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Hgmma_rfa_bf16::mma(a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/utils_igmma.h b/csrc/fmha_v2/fmha/hopper/utils_igmma.h new file mode 100644 index 0000000000..fcced80616 --- /dev/null 
+++ b/csrc/fmha_v2/fmha/hopper/utils_igmma.h @@ -0,0 +1,396 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// IGMMA 64xNx32 TN with int32 Accumulator with A and B from SMEM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Igmma_int8_int32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Igmma_int8_int32<64> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + " %32, %33, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Igmma_int8_int32<128> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + " %64, %65, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), 
"+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Igmma_int8_int32<192> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + " %96, %97, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> 
+struct Igmma_int8_int32<256> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + " %128, %129, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void igmma_int8_int32(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[N / 2]) { + Igmma_int8_int32::mma(desc_a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// IGMMA 64xNx32 TN with int32 Accumulator with 
A from RF and B from SMEM. +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Igmma_rfa_int8_int32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Igmma_rfa_int8_int32<64> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + "{ %32, %33, %34, %35 }, %36, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Igmma_rfa_int8_int32<128> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + "{ %64, %65, %66, %67 }, %68, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + 
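The int8 wrappers in this file differ from the fp16/bf16 ones in two ways: k is 32 instead of 16, and the s8.s8 instruction is TN-only, so there are no transpose immediates; the int32 accumulator still occupies N/2 registers per thread, one 32-bit value each. Below is a hedged sketch of issuing the SMEM x SMEM dispatcher igmma_int8_int32 defined earlier in this file and dequantizing the result; the helper name, the single scale factor, and the fence placement are assumptions, not part of this patch.

// Illustrative sketch only: one 64x128x32 s8.s8 GMMA followed by per-element
// dequantization of the int32 accumulator. desc_a/desc_b are the usual shared-memory
// descriptors; `scale` is a hypothetical dequantization factor.
__device__ inline void example_issue_igmma_s8_64x128(uint64_t desc_a, uint64_t desc_b,
                                                     uint32_t (&acc)[64], float scale,
                                                     float (&out)[64]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
  asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
  // N = 128 with an int32 accumulator: 128/2 = 64 registers per thread.
  fmha::igmma_int8_int32<128>(desc_a, desc_b, acc);
  asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
  asm volatile("wgmma.wait_group.sync.aligned 0;\n" ::: "memory");
#pragma unroll
  for (int i = 0; i < 64; ++i) {
    // Accumulators are s32 bit patterns carried in uint32_t registers.
    out[i] = static_cast<float>(static_cast<int32_t>(acc[i])) * scale;
  }
#endif
}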
+//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Igmma_rfa_int8_int32<192> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + "{ %96, %97, %98, %99 }, %100, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Igmma_rfa_int8_int32<256> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, 
%57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + "{ %128, %129, %130, %131 }, %132, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void igmma_rfa_int8_int32(uint32_t const (&a)[4], uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Igmma_rfa_int8_int32::mma(a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/utils_qgmma.h b/csrc/fmha_v2/fmha/hopper/utils_qgmma.h new file mode 100644 index 0000000000..28571b15b9 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/utils_qgmma.h @@ -0,0 +1,2089 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. 
SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// QGMMA 64xNx32 TN with int32 Accumulator with A and B from SMEM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qgmma_e4m3_e4m3_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e4m3_fp32<32> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15\n" + "},\n" + " %16, %17, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e4m3_fp32<64> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + " %32, %33, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e4m3_fp32<128> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + " %64, %65, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e4m3_fp32<192> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + " %96, %97, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), 
"+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e4m3_fp32<256> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + " %128, %129, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + 
"+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void qgmma_e4m3_e4m3_fp32(uint64_t desc_a, uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Qgmma_e4m3_e4m3_fp32::mma(desc_a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// QGMMA 64xNx32 TN with int32 Accumulator with A from RF and B from SMEM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qgmma_rfa_e4m3_e4m3_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e4m3_fp32<32> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15\n" + "},\n" + "{ %16, %17, %18, %19 }, %20, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e4m3_fp32<64> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + "{ %32, %33, %34, %35 }, %36, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e4m3_fp32<128> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + "{ %64, %65, %66, %67 }, %68, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e4m3_fp32<192> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + "{ %96, %97, %98, %99 }, %100, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), 
"+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e4m3_fp32<256> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + "{ %128, %129, %130, %131 }, %132, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), 
"+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void qgmma_rfa_e4m3_e4m3_fp32(uint32_t const (&a)[4], uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Qgmma_rfa_e4m3_e4m3_fp32::mma(a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// QGMMA e4m3 x e5m2 - 64xNx32 TN with int32 Accumulator with A and B from SMEM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qgmma_e4m3_e5m2_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e5m2_fp32<8> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7\n" + "},\n" + " %8, %9, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e5m2_fp32<32> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15 \n" + "},\n" + " %16, %17, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e5m2_fp32<64> { 
+ static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + " %32, %33, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e5m2_fp32<128> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + " %64, %65, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x160x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e5m2_fp32<160> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[80]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " 
%8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79 \n" + "},\n" + " %80, %81, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e5m2_fp32<192> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + " %96, %97, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), 
"+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e5m2_fp32<256> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + " %128, %129, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), 
"+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void qgmma_e4m3_e5m2_fp32(uint64_t desc_a, uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Qgmma_e4m3_e5m2_fp32::mma(desc_a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// QGMMA e4m3 x e5m2 - 64xNx32 TN with int32 Accumulator with A from RF and B from SMEM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qgmma_rfa_e4m3_e5m2_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e5m2_fp32<8> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7\n" + "},\n" + "{ %8, %9, %10, %11 }, %12, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e5m2_fp32<32> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15 \n" + "},\n" + "{ %16, %17, %18, %19 }, %20, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e5m2_fp32<64> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && 
defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + "{ %32, %33, %34, %35 }, %36, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e5m2_fp32<128> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + "{ %64, %65, %66, %67 }, %68, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x160x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e5m2_fp32<160> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[80]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, 
%14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79 \n" + "},\n" + "{ %80, %81, %82, %83 }, %84, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e5m2_fp32<192> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + "{ %96, %97, %98, %99 }, %100, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), 
"+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e5m2_fp32<256> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + "{ %128, %129, %130, %131 }, %132, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + 
"+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void qgmma_rfa_e4m3_e5m2_fp32(uint32_t const (&a)[4], uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Qgmma_rfa_e4m3_e5m2_fp32::mma(a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// QGMMA e5m2 x e4m3 - 64xNx32 TN with int32 Accumulator with A and B from SMEM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qgmma_e5m2_e4m3_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e4m3_fp32<8> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7\n" + "},\n" + " %8, %9, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e4m3_fp32<32> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15 \n" + "},\n" + " %16, %17, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e4m3_fp32<64> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && 
defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + " %32, %33, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e4m3_fp32<128> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + " %64, %65, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x160x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e4m3_fp32<160> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[80]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, 
%34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79 \n" + "},\n" + " %80, %81, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e4m3_fp32<192> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + " %96, %97, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), 
"+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e4m3_fp32<256> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + " %128, %129, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + 
"+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void qgmma_e5m2_e4m3_fp32(uint64_t desc_a, uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Qgmma_e5m2_e4m3_fp32::mma(desc_a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// QGMMA e5m2 x e4m3 - 64xNx32 TN with int32 Accumulator with A from RF and B from SMEM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qgmma_rfa_e5m2_e4m3_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e4m3_fp32<8> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7\n" + "},\n" + "{ %8, %9, %10, %11 }, %12, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e4m3_fp32<32> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15 \n" + "},\n" + "{ %16, %17, %18, %19 }, %20, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e4m3_fp32<64> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, 
%12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + "{ %32, %33, %34, %35 }, %36, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e4m3_fp32<128> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + "{ %64, %65, %66, %67 }, %68, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x160x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e4m3_fp32<160> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[80]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, 
%44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79 \n" + "},\n" + "{ %80, %81, %82, %83 }, %84, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e4m3_fp32<192> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + "{ %96, %97, %98, %99 }, %100, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + 
"+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e4m3_fp32<256> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + "{ %128, %129, %130, %131 }, %132, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), 
"+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void qgmma_rfa_e5m2_e4m3_fp32(uint32_t const (&a)[4], uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Qgmma_rfa_e5m2_e4m3_fp32::mma(a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// QGMMA e5m2 x e5m2 - 64xNx32 TN with int32 Accumulator with A and B from SMEM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qgmma_e5m2_e5m2_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e5m2_fp32<8> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7\n" + "},\n" + " %8, %9, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e5m2_fp32<64> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + " %32, %33, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct 
Qgmma_e5m2_e5m2_fp32<128> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + " %64, %65, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x160x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e5m2_fp32<160> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[80]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79 \n" + "},\n" + " %80, %81, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), 
"+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e5m2_fp32<192> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + " %96, %97, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e5m2_fp32<256> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && 
defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + " %128, %129, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void qgmma_e5m2_e5m2_fp32(uint64_t desc_a, uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Qgmma_e5m2_e5m2_fp32::mma(desc_a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// QGMMA e5m2 x e5m2 - 64xNx32 TN with int32 Accumulator with A from RF and B from SMEM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template 
+struct Qgmma_rfa_e5m2_e5m2_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e5m2_fp32<8> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7\n" + "},\n" + "{ %8, %9, %10, %11 }, %12, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e5m2_fp32<32> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15 \n" + "},\n" + "{ %16, %17, %18, %19 }, %20, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e5m2_fp32<64> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + "{ %32, %33, %34, %35 }, %36, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e5m2_fp32<128> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t 
desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + "{ %64, %65, %66, %67 }, %68, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x160x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e5m2_fp32<160> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[80]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79 \n" + "},\n" + "{ %80, %81, %82, %83 }, %84, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), 
"+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e5m2_fp32<192> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + "{ %96, %97, %98, %99 }, %100, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e5m2_fp32<256> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t 
(&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + "{ %128, %129, %130, %131 }, %132, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void qgmma_rfa_e5m2_e5m2_fp32(uint32_t const (&a)[4], uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Qgmma_rfa_e5m2_e5m2_fp32::mma(a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/utils_tma.h 
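// Illustrative usage sketch (not part of this change; the helper name and the N = 128 tile size
// are hypothetical). The Qgmma_* wrappers above only issue the asynchronous wgmma instruction, so
// the caller drives the fence/commit/wait protocol from fmha/hopper/utils_warpgroup.h added later
// in this diff. desc_a/desc_b are assumed to be valid GMMA shared-memory matrix descriptors; the
// uint32_t accumulators hold f32 bit patterns and start zero-initialized, since the wrappers are
// emitted with scale-d = 1 and therefore accumulate into acc.
inline __device__ void qgmma_e5m2_e5m2_usage_sketch(uint64_t desc_a, uint64_t desc_b,
                                                    float (&out)[64]) {
  uint32_t acc[128 / 2] = {};  // 64 f32 accumulators per thread for an m64n128k32 tile
  fmha::warpgroup_arrive();    // wgmma.fence: make prior SMEM/RF writes visible to wgmma
  fmha::qgmma_e5m2_e5m2_fp32<128>(desc_a, desc_b, acc);  // A and B both read via SMEM descriptors
  // If the A fragment is already in registers (4 x 32b per thread), the rfa variant is used:
  //   fmha::qgmma_rfa_e5m2_e5m2_fp32<128>(a_frag, desc_b, acc);
  fmha::warpgroup_commit();    // wgmma.commit_group
  fmha::warpgroup_wait<0>();   // block until every committed group has retired
#pragma unroll
  for (int ii = 0; ii < 64; ++ii) {
    out[ii] = __uint_as_float(acc[ii]);  // accumulators are IEEE fp32 stored in uint32_t registers
  }
}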
b/csrc/fmha_v2/fmha/hopper/utils_tma.h new file mode 100644 index 0000000000..faa63edb81 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/utils_tma.h @@ -0,0 +1,155 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include + +namespace fmha { + +inline __device__ uint32_t elect_one_sync(); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void utmaldg(cudaTmaDesc const* p_desc, // TMA desc + uint32_t smem_ptr, // desc smem address + uint32_t smem_barrier, // smem_barrier + int32_t const (&coord)[DIM], // coord + uint32_t elect_one = 1); + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// UTMALDG TILED WITHOUT MULTICAST +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +inline __device__ void utmaldg<2, fmha::cudaTmaDescType::TILED, false>(cudaTmaDesc const* p_desc, + uint32_t smem_ptr, + uint32_t smem_barrier, + int32_t const (&coord)[2], + uint32_t elect_one) { + if (elect_one) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile( + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes " + "[%0], [%1, {%2, %3}], [%4];\n" + : + : "r"(smem_ptr), "l"(reinterpret_cast(p_desc)), "r"(coord[0]), "r"(coord[1]), + "r"(smem_barrier) + : "memory"); +#endif + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +inline __device__ void utmaldg<3, fmha::cudaTmaDescType::TILED, false>(cudaTmaDesc const* p_desc, + uint32_t smem_ptr, + uint32_t smem_barrier, + int32_t const (&coord)[3], + uint32_t elect_one) { + if (elect_one) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile( + "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes " + "[%0], [%1, {%2, %3, %4}], [%5];\n" + : + : "r"(smem_ptr), "l"(reinterpret_cast(p_desc)), "r"(coord[0]), "r"(coord[1]), + "r"(coord[2]), "r"(smem_barrier) + : "memory"); +#endif + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 4D, TILED, without Multicast +template <> +inline __device__ void utmaldg<4, fmha::cudaTmaDescType::TILED, false>(cudaTmaDesc const* p_desc, + uint32_t smem_ptr, + uint32_t smem_barrier, + int32_t const (&coord)[4], + uint32_t elect_one) { + if (elect_one) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile( + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes " + "[%0], [%1, {%2, %3, %4, %5}], [%6];\n" + : + : "r"(smem_ptr), "l"(reinterpret_cast(p_desc)), "r"(coord[0]), "r"(coord[1]), + "r"(coord[2]), "r"(coord[3]), "r"(smem_barrier) + : "memory"); +#endif + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// UTMASTG TILED WITHOUT MULTICAST +// 
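// Illustrative sketch only: the TILED store path declared below is issued by a single thread (the
// load path above guards with elect_one; the store path leaves that choice to the caller) and is
// fenced with the tmastg_arrive()/tmastg_wait() helpers at the end of this file. The names p_desc,
// smem_o and the coordinate values are assumptions for the example: p_desc is a TMA descriptor
// prepared on the host for a 3D TILED tensor, and smem_o is the shared-memory address of the tile
// being written out.
//
//   int32_t coord[3] = {c0, c1, c2};                               // box coordinates in the tensor
//   utmastg<3, fmha::cudaTmaDescType::TILED>(p_desc, smem_o, coord);
//   tmastg_arrive();   // cp.async.bulk.commit_group
//   tmastg_wait();     // cp.async.bulk.wait_group.read 0: the SMEM tile may be reused afterwards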
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void utmastg(cudaTmaDesc const* p_desc, // TMA desc + uint32_t smem_ptr, // src smem address + int32_t const (&coord)[DIM]); // coord + +// 3D, TILED +template <> +inline __device__ void utmastg<3, fmha::cudaTmaDescType::TILED>(cudaTmaDesc const* p_desc, + uint32_t smem_ptr, + const int32_t (&coord)[3]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile( + "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, {%1, %2, %3}], [%4];\n" ::"l"( + reinterpret_cast(p_desc)), + "r"(coord[0]), "r"(coord[1]), "r"(coord[2]), "r"(smem_ptr) + : "memory"); +#endif +} + +// 4D, TILED +template <> +inline __device__ void utmastg<4, fmha::cudaTmaDescType::TILED>(cudaTmaDesc const* p_desc, + uint32_t smem_ptr, + int32_t const (&coord)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile( + "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, {%1, %2, %3, %4}], [%5];\n" ::"l"( + reinterpret_cast(p_desc)), + "r"(coord[0]), "r"(coord[1]), "r"(coord[2]), "r"(coord[3]), "r"(smem_ptr) + : "memory"); +#endif +} + +inline __device__ void tmastg_arrive() { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("cp.async.bulk.commit_group;"); +#else + assert(false); +#endif +} + +inline __device__ void tmastg_wait() { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(0) : "memory"); +#else + assert(false); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/utils_warpgroup.h b/csrc/fmha_v2/fmha/hopper/utils_warpgroup.h new file mode 100644 index 0000000000..8923316f61 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/utils_warpgroup.h @@ -0,0 +1,44 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#pragma once + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void warpgroup_arrive() { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile("wgmma.fence.sync.aligned;\n" ::); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void warpgroup_commit() { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile("wgmma.commit_group.sync.aligned;\n" ::); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void warpgroup_wait() { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/kernel_traits.h b/csrc/fmha_v2/fmha/kernel_traits.h new file mode 100644 index 0000000000..8e1d5cbb22 --- /dev/null +++ b/csrc/fmha_v2/fmha/kernel_traits.h @@ -0,0 +1,879 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Ada hmma/imma reuses Ampere +template +struct Traits_reuse { + using Traits = Traits_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Traits_reuse { + using Traits = fmha::Ampere_hmma_fp16_traits; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Traits_reuse { + using Traits = fmha::Ampere_hmma_fp32_traits; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Traits_reuse { + using Traits = fmha::Ampere_imma_int8_int32_traits; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Traits_o_adapter { + using Traits = Traits_p; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Traits_o_adapter { + using Traits = fmha::Volta_hmma_fp16_16x16x16_traits; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// convert to fp16 before smem_o store +template <> +struct Traits_o_adapter { + using Traits = fmha::Ampere_hmma_fp16_traits; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// convert to fp16 before smem_o store +template <> +struct Traits_o_adapter { + using Traits = fmha::Turing_hmma_fp16_traits; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// convert to bf16 before smem_o store +template <> +struct Traits_o_adapter { + using Traits = fmha::Ampere_hmma_bf16_bf16_traits; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // Instruction traits. + typename Traits_, + // The global memory tile for Q, K and V. + template class Gmem_tile_q_, + template class Gmem_tile_k_, + template class Gmem_tile_v_, + // The global memory tile for the output. + template class Gmem_tile_o_, + // Sequence length. + int S, + // The valid hidden dimension. + int VALID_D_, + // The valid hidden dimension of V. + int VALID_DV_, + // The iteration step of the outer loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD_, + // The flags to control the behaviour of LDGs. + uint32_t FLAGS, + // The version of the kernel. + int VERSION_, + // The mask version of the kernel + int MASK_VERSION_, + // Do we use half epilogue for the 2nd GEMM (hmma_fp32) + bool BMM2_FP16_EPILOGUE = true, + // non-positive means disabled + int SAGE_BLOCK_SIZE_Q_ = 0, int SAGE_BLOCK_SIZE_K_ = 0, int SAGE_BLOCK_SIZE_V_ = 0> +struct Kernel_traits_ { + // The instruction traits for the Q*K product. + using Traits_p = typename Traits_reuse::Traits; + // The instruction traits for the P*V product. Hack to change the traits for Volta HMMA. + using Traits_o = typename Traits_o_adapter::Traits; + // The instruction traits for the epilogue of the 2nd GEMM. 
Always use FP16. + using Traits_e = typename Traits_o_adapter::Traits; + + // The padded D dimension + enum { VALID_D = VALID_D_ }; + + enum { D = Next_power_of_two::VALUE }; + + enum { VALID_DV = VALID_DV_ > 0 ? VALID_DV_ : VALID_D }; + + enum { DV = Next_power_of_two::VALUE }; + + enum { + SAGE_ATTENTION = SAGE_BLOCK_SIZE_Q_ > 0 || SAGE_BLOCK_SIZE_K_ > 0 || SAGE_BLOCK_SIZE_V_ > 0 + }; + + enum { SAGE_BLOCK_SIZE_Q = SAGE_BLOCK_SIZE_Q_ }; + + enum { SAGE_BLOCK_SIZE_K = SAGE_BLOCK_SIZE_K_ }; + + enum { SAGE_BLOCK_SIZE_V = SAGE_BLOCK_SIZE_V_ }; + + // TODO: expose these tiling params to the interface + enum { USE_GRANULAR_TILING = (FLAGS & 0x1000) != 0u }; // TODO ANT: check FLAGS + + using Traits_tile_size = + Traits_tile_size<(bool)USE_GRANULAR_TILING, STEP, S, D, DV, Traits_o::K_PER_MMA>; + + enum { CTA_P_TILE_M = Traits_tile_size::CTA_P_TILE_M }; + + enum { CTA_P_TILE_N = Traits_tile_size::CTA_P_TILE_N }; + + enum { CTA_P_TILE_K = Traits_tile_size::CTA_P_TILE_K }; + + enum { CTA_O_TILE_M = Traits_tile_size::CTA_O_TILE_M }; + + enum { CTA_O_TILE_N = Traits_tile_size::CTA_O_TILE_N }; + + enum { CTA_O_TILE_K = Traits_tile_size::CTA_O_TILE_K }; + + // Do we need to reload Q due to splitting the D ? + enum { RELOAD_Q = static_cast(CTA_P_TILE_K) != static_cast(D) }; + + // The CTA description for the 1st GEMM. + using Cta_tile_p = + typename Traits_p::template Cta_tile_extd; + // The CTA description for the 2nd GEMM. + using Cta_tile_o = + typename Traits_o::template Cta_tile_extd; + + // The MMA tile for the 1st GEMM. + using Mma_tile_p = typename Traits_p::template Mma_tile; + // The MMA tile for the 2nd GEMM. + using Mma_tile_o = typename Traits_o::template Mma_tile; + + // Compute the total BMM2_MMAS_K (might not the same as Mma_tile_o::MMAS_K if the granular tiling + // is used). + static_assert(S % CTA_O_TILE_K == 0, ""); + + enum { TOTAL_BMM2_MMAS_K = Mma_tile_o::MMAS_K * (S / CTA_O_TILE_K) }; + + // Constraints on the K dimension. + static_assert(Mma_tile_p::K_PER_MMA <= static_cast(D)); + static_assert(Mma_tile_o::K_PER_MMA <= S); + + // The version. + enum { VERSION = VERSION_ }; + + // The mask version: padding (2), causal (3), sliding_window_causal (4), custom_mask (5). + enum { MASK_VERSION = MASK_VERSION_ }; + + // Whether use causal mask or not. + enum { CAUSAL_MASK = MASK_VERSION_ == 3 || MASK_VERSION_ == 4 }; + + // Whether use the sliding window attention or not. + enum { SLIDING_WINDOW_ATTENTION = MASK_VERSION_ == 4 }; + + // Whether use the custom mask or not. + enum { CUSTOM_MASK = MASK_VERSION_ == 5 }; + + // Do we use LDGSTS for Q, K or V. + enum { USE_LDGSTS_Q = (FLAGS & 0x1u) != 0u }; + + enum { USE_LDGSTS_K = (FLAGS & 0x2u) != 0u }; + + enum { USE_LDGSTS_V = (FLAGS & 0x4u) != 0u }; + + // Do we use one buffer for K and V. + enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u }; + + // Do we use the scale max trick. + enum { USE_SCALE_MAX = (FLAGS & 0x10u) != 0u }; + + // Are heads in QKV interleaved, i.e. total x h x 3 x d or total x 3 x h x d. + enum { HEADS_INTERLEAVED = (FLAGS & 0x20u) == 0u }; + + // Keep full K matrix in registers. 
+ enum { K_IN_REGS = (FLAGS & 0x40) == 0u }; + + // Do we use only 2 fragments or full fragments for frag_q/k (only used by flash attention) + enum { LIMIT_QK_FRAGMENTS = ((FLAGS & 0x80u) != 0u && !SHARE_SMEM_FOR_K_AND_V) }; + + // Do we use only 2 fragments or full fragments for frag_v (only used by flash attention) + enum { LIMIT_V_FRAGMENTS = ((FLAGS & 0x100u) != 0u && !SHARE_SMEM_FOR_K_AND_V) }; + + // Limiting QK fragments implies SMEM_K has to reside in SMEM + static_assert(!(LIMIT_QK_FRAGMENTS && SHARE_SMEM_FOR_K_AND_V), ""); + + // Indicates that kernel does not loop over Q tensor, usually kernel name has _nl suffix + enum { NO_LOOP = (FLAGS & 0x200u) != 0u }; + + // Are sequences in one batch interleaved. i.e. s x b x ..., or b x s x ... + enum { SEQUENCES_INTERLEAVED = (FLAGS & 0x400) != 0u }; + + // Use BMM1 softcapping scale or not. + enum { ENABLE_BMM1_SOFTCAPPING_SCALE = (FLAGS & 0x800) != 0u }; + + // Use MTP (multi-token prediction for MLA kernels) or not. + enum { IS_MTP = (FLAGS & 0x2000) != 0u }; + + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + enum { CTAS_PER_HEAD = CTAS_PER_HEAD_ }; + + // The number of shared memory buffers to build a software pipeline for Q, K and V. + enum { + BUFFERS_PER_TILE_SMEM_Q = (USE_GRANULAR_TILING && D > 64) || (USE_LDGSTS_Q && !NO_LOOP) ? 2 : 1 + }; + + enum { BUFFERS_PER_TILE_SMEM_K = USE_GRANULAR_TILING ? 2 : 1 }; + + enum { BUFFERS_PER_TILE_SMEM_V = USE_GRANULAR_TILING ? 2 : 1 }; + + // The global memory tile to load Q. + using Gmem_tile_q = Gmem_tile_q_; + + // The shared memory tile to swizzle Q. + using Smem_tile_q = fmha::Smem_tile_a; + + // The global memory tile to load K. + using Gmem_tile_k = Gmem_tile_k_; + + // The shared memory tile to swizzle K. + using Smem_tile_k = fmha::Smem_tile_b; + + // The global memory tile to load V. + using Gmem_tile_v = Gmem_tile_v_; + + // The shared memory tile to swizzle V. + using Smem_tile_v = fmha::Smem_tile_v; + + // The global memory tile to store O. + using Gmem_tile_o = Gmem_tile_o_; + // The shared memory tile for O. + using Smem_tile_o = fmha::Smem_tile_o; + + // Make sure the number of threads match. + static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, ""); + + // The number of threads. + enum { THREADS = Cta_tile_p::THREADS_PER_CTA }; + + // Make sure the number of threads matches both CTAs. + static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, ""); + + // The amount of shared memory needed to load Q and K. + enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE }; + + // The extra amount of shared memory needed to load V. + enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE }; + + // The amount of shared memory needed for Q, K and V.. + enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V }; + + // The amount of shared memory needed to load/store O. + enum { BYTES_PER_SMEM_O = Smem_tile_o::BYTES_PER_TILE }; + + // The amount of shared memory needed to load Q and store O. + enum { + BYTES_PER_SMEM_QO = + NO_LOOP ? Smem_tile_o::BYTES_PER_TILE : Smem_tile_q::BYTES_PER_TILE + BYTES_PER_SMEM_O + }; + + // The amount of shared memory needed for Q, K, V and O. + enum { BYTES_PER_SMEM = fmha::Max::VALUE }; + + // Make sure we have enough shared memory. + static_assert((NO_LOOP + ? 
Smem_tile_o::BYTES_PER_TILE + : Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE) <= BYTES_PER_SMEM, + ""); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // Instruction traits. + typename Traits_, + // The global memory tile for Q, K and V. + template class Gmem_tile_q_, + // The global memory tile for the output. + template class Gmem_tile_o_, + // Sequence length for K/V. + int S_KV, + // The hidden dimension. + int D, + // The iteration step of the outer loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD_, + // The flags to control the behaviour of LDGs. + uint32_t FLAGS, + // The version of the kernel. + int VERSION_, + // Do we use half epilogue for the 2nd GEMM (hmma_fp32) + bool BMM2_FP16_EPILOGUE = true> +struct Kernel_traits_fmhca_ { + // The instruction traits for the Q*K product. + using Traits_p = typename Traits_reuse::Traits; + // The instruction traits for the P*V product. Hack to change the traits for Volta HMMA. + using Traits_o = typename Traits_o_adapter::Traits; + // The instruction traits for the epilogue of the 2nd GEMM. Always use FP16. + using Traits_e = typename Traits_o_adapter::Traits; + + // The CTA description for the 1st GEMM. + using Cta_tile_p = + typename Traits_p::template Cta_tile_extd; + // The CTA description for the 2nd GEMM. + using Cta_tile_o = + typename Traits_o::template Cta_tile_extd; + + // The MMA tile for the 1st GEMM. + using Mma_tile_p = typename Traits_p::template Mma_tile; + // The MMA tile for the 2nd GEMM. + using Mma_tile_o = typename Traits_o::template Mma_tile; + + // Constraints on the K dimension. + static_assert(Mma_tile_p::K_PER_MMA <= D, ""); + static_assert(Mma_tile_o::K_PER_MMA <= S_KV, ""); + + // The version. + enum { VERSION = VERSION_ }; + + // The mask version + enum { MASK_VERSION = VERSION_ }; + + // Whether use causal mask or not. + enum { CAUSAL_MASK = MASK_VERSION >= 3 }; + + // Whether use the sliding window attention or not. + enum { SLIDING_WINDOW_ATTENTION = MASK_VERSION == 4 }; + + // Do we use LDGSTS for Q, K or V. + enum { USE_LDGSTS_Q = (FLAGS & 0x1u) != 0u }; + + enum { USE_LDGSTS_K = (FLAGS & 0x2u) != 0u }; + + enum { USE_LDGSTS_V = (FLAGS & 0x4u) != 0u }; + + // Do we use one buffer for K and V. + enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u }; + + // Do we use the scale max trick. + enum { USE_SCALE_MAX = (FLAGS & 0x10u) != 0u }; + + // Are heads in QKV interleaved, i.e. total x h x 3 x d or total x 3 x h x d. + enum { HEADS_INTERLEAVED = (FLAGS & 0x20u) == 0u }; + + // Keep full K matrix in registers. + enum { K_IN_REGS = (FLAGS & 0x40) == 0u }; + + // Use BMM1 softcapping scale or not. + enum { ENABLE_BMM1_SOFTCAPPING_SCALE = 0 }; + + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + enum { CTAS_PER_HEAD = CTAS_PER_HEAD_ }; + + // The global memory tile to load Q. + using Gmem_tile_q = Gmem_tile_q_; + + // The shared memory tile to swizzle Q. + using Smem_tile_q = fmha::Smem_tile_a; + + // The global memory tile to load K. + using Gmem_tile_k = Gmem_tile_q_; + + // The shared memory tile to swizzle K. + using Smem_tile_k = fmha::Smem_tile_b; + + // The global memory tile to load V. + using Gmem_tile_v = Gmem_tile_q_; + + // The shared memory tile to swizzle V. 
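Editor's note: the FLAGS template parameter of the Kernel_traits_ struct above is a bit-field, and the enums decode it bit by bit (0x1/0x2/0x4 select LDGSTS for Q/K/V, 0x8 shares the K/V smem buffer, a set 0x20 or 0x40 bit *disables* HEADS_INTERLEAVED or K_IN_REGS, 0x200 marks the no-loop variant, 0x1000 enables granular tiling, and so on). A small standalone decoder that mirrors those enum definitions makes the encoding easy to check; the program itself is not part of the diff.

```cpp
// Standalone sketch (not part of this diff): decode the Kernel_traits_ FLAGS
// bit-field exactly as the enums above interpret it.
#include <cstdint>
#include <cstdio>

void decode_kernel_traits_flags(uint32_t flags) {
  std::printf("USE_LDGSTS_Q                  = %d\n", (flags & 0x1u) != 0u);
  std::printf("USE_LDGSTS_K                  = %d\n", (flags & 0x2u) != 0u);
  std::printf("USE_LDGSTS_V                  = %d\n", (flags & 0x4u) != 0u);
  std::printf("SHARE_SMEM_FOR_K_AND_V        = %d\n", (flags & 0x8u) != 0u);
  std::printf("USE_SCALE_MAX                 = %d\n", (flags & 0x10u) != 0u);
  std::printf("HEADS_INTERLEAVED             = %d\n", (flags & 0x20u) == 0u);  // set bit disables
  std::printf("K_IN_REGS                     = %d\n", (flags & 0x40u) == 0u);  // set bit disables
  // In the traits, the two LIMIT_* bits are additionally gated on not sharing K/V smem.
  std::printf("LIMIT_QK_FRAGMENTS            = %d\n", (flags & 0x80u) != 0u);
  std::printf("LIMIT_V_FRAGMENTS             = %d\n", (flags & 0x100u) != 0u);
  std::printf("NO_LOOP                       = %d\n", (flags & 0x200u) != 0u);
  std::printf("SEQUENCES_INTERLEAVED         = %d\n", (flags & 0x400u) != 0u);
  std::printf("ENABLE_BMM1_SOFTCAPPING_SCALE = %d\n", (flags & 0x800u) != 0u);
  std::printf("USE_GRANULAR_TILING           = %d\n", (flags & 0x1000u) != 0u);
  std::printf("IS_MTP                        = %d\n", (flags & 0x2000u) != 0u);
}

int main() {
  decode_kernel_traits_flags(0x8u);  // the default FLAGS value used by the aliases below
  return 0;
}
```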
+ using Smem_tile_v = fmha::Smem_tile_v; + + // The global memory tile to store O. + using Gmem_tile_o = Gmem_tile_o_; + // The shared memory tile for O. + using Smem_tile_o = fmha::Smem_tile_o; + + // Make sure the number of threads match. + static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, ""); + + // The number of threads. + enum { THREADS = Cta_tile_p::THREADS_PER_CTA }; + + // Make sure the number of threads matches both CTAs. + static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, ""); + + // The amount of shared memory needed to load Q and K. + enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE }; + + // The extra amount of shared memory needed to load V. + enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE }; + + // The amount of shared memory needed for Q, K and V.. + enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V }; + + // The amount of shared memory needed to load Q and store O. + enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE }; + + // The amount of shared memory needed for Q, K, V and O. + enum { BYTES_PER_SMEM = fmha::Max::VALUE }; + + // Make sure we have enough shared memory. + static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, ""); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits_, + // The sequence length. + int S, + // The hidden size per head. + int VALID_D, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD_, + // The flags. + uint32_t FLAGS = 0x8, + // The mask version of the kernel + int MASK_VERSION_ = 2> +struct Kernel_traits_interleaved_v2_ { + // The instruction traits. + using Traits = typename Traits_reuse::Traits; + using Traits_p = Traits; + using Traits_o = Traits; + + // The padded D dimension + enum { D = Next_power_of_two::VALUE }; + + // The CTA description for the 1st GEMM. + using Cta_tile_p = + typename Traits::template Cta_tile_extd; + // The CTA description for the 2nd GEMM. + using Cta_tile_o = + typename Traits::template Cta_tile_extd; + + // The version. + enum { VERSION = 2 }; + + enum { MASK_VERSION = MASK_VERSION_ }; + + // Whether use causal mask or not. + enum { CAUSAL_MASK = MASK_VERSION_ >= 3 }; + + // Whether use the sliding window attention or not. + enum { SLIDING_WINDOW_ATTENTION = MASK_VERSION_ == 4 }; + + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + enum { CTAS_PER_HEAD = CTAS_PER_HEAD_ }; + + // Do we use LDGSTS for Q, K or V. + enum { USE_LDGSTS_Q = (FLAGS & 0x1u) != 0u }; + + enum { USE_LDGSTS_K = (FLAGS & 0x2u) != 0u }; + + enum { USE_LDGSTS_V = (FLAGS & 0x4u) != 0u }; + + // Do we use one buffer for K and V. + enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u }; + + // Do we use the scale max trick. + enum { USE_SCALE_MAX = (FLAGS & 16) != 0u }; + + // The global memory tile to load Q. + using Gmem_tile_q = + fmha::v2::Gmem_tile_qkv_interleaved; + // The shared memory tile to swizzle Q. + using Smem_tile_q = fmha::Smem_tile_qk_interleaved_a; + + // The global memory tile to load K. 
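Editor's note: in both trait structs the shared-memory budget resolves to max(Q + K + V tiles, Q tile + O tile), with the V tile dropping out when SHARE_SMEM_FOR_K_AND_V is set (and, in the non-cross-attention traits, the Q term dropping out of the Q+O side for NO_LOOP kernels, which is ignored here). A constexpr sketch with made-up per-tile byte counts (the real values come from the Smem_tile_* types) illustrates the arithmetic; it is not part of the diff.

```cpp
// Standalone sketch (not part of this diff): the shared-memory budget rule used by
// the kernel traits above, with hypothetical per-tile byte counts.
#include <algorithm>
#include <cstdio>

constexpr int smem_budget_bytes(int bytes_q, int bytes_k, int bytes_v, int bytes_o,
                                bool share_smem_for_k_and_v) {
  int const bytes_qk = bytes_q + bytes_k;
  int const bytes_v_extra = share_smem_for_k_and_v ? 0 : bytes_v;  // V reuses the K buffer if shared
  int const bytes_qkv = bytes_qk + bytes_v_extra;
  int const bytes_qo = bytes_q + bytes_o;  // (NO_LOOP variants drop the Q term here)
  return std::max(bytes_qkv, bytes_qo);
}

int main() {
  // Hypothetical tile sizes: 16 KB each for Q/K/V, 8 KB for O.
  std::printf("separate K/V buffers: %d bytes\n",
              smem_budget_bytes(16384, 16384, 16384, 8192, false));
  std::printf("shared  K/V buffer  : %d bytes\n",
              smem_budget_bytes(16384, 16384, 16384, 8192, true));
  return 0;
}
```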
+ using Gmem_tile_k = + fmha::v2::Gmem_tile_qkv_interleaved; + // The shared memory tile to swizzle K. + using Smem_tile_k = fmha::Smem_tile_qk_interleaved_b; + + // The global memory tile to load V. + using Gmem_tile_v = + fmha::v2::Gmem_tile_qkv_interleaved; + + // The shared memory tile to swizzle V. + using Smem_tile_v = fmha::Smem_tile_v_interleaved_b; + + // The global memory tile to store O. + using Gmem_tile_o = fmha::v2::Imma_gmem_tile_o_interleaved; + // The shared memory tile for O. + using Smem_tile_o = fmha::Smem_tile_o_interleaved; + + // Make sure the number of threads match. + static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, ""); + + // The number of threads. + enum { THREADS = Cta_tile_p::THREADS_PER_CTA }; + + // Make sure the number of threads matches both CTAs. + static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, ""); + + // The amount of shared memory needed to load Q and K. + enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE }; + + // The extra amount of shared memory needed to load V. + enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE }; + + // The amount of shared memory needed for Q, K and V.. + enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V }; + + // The amount of shared memory needed to load Q and store O. + enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE }; + + // The amount of shared memory needed for Q, K, V and O. + enum { BYTES_PER_SMEM = fmha::Max::VALUE }; + + // Make sure we have enough shared memory. + static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, ""); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits_, + // The sequence length. + int S, + // The hidden size per head. + int VALID_D, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD_, + // The flags. + uint32_t FLAGS = 0x8, + // The mask version of the kernel + int MASK_VERSION_ = 2> +using Kernel_traits_interleaved_v2 = + Kernel_traits_interleaved_v2_; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The sequence length. + int S, + // The hidden size per head. + int D, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD, + // The flags. + uint32_t FLAGS = 0x8> +using Kernel_traits_v1 = Kernel_traits_; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The sequence length. + int S, + // The hidden size per head. + int D, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD, + // The flags. 
+ uint32_t FLAGS = 0x8> +using Kernel_traits_v1_causal_mask = + Kernel_traits_; // MASK_VERSION_ + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_dispatcher { + template + using Gmem_tile_o = fmha::v2::Gmem_tile_o; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Gmem_tile_o_dispatcher { + template + using Gmem_tile_o = fmha::v2::Gmem_tile_o_uint16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Gmem_tile_o_dispatcher { + template + using Gmem_tile_o = fmha::v2::Gmem_tile_o_bfloat16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The sequence length. + int S, + // The hidden size per head. + int D, + // The hidden dimension of V. + int DV, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD, + // The flags. + uint32_t FLAGS = 0x8, + // The attention mask version (see src/mask.h). + int MASK_VERSION = 2, + // Do we use half epilogue for the 2nd GEMM (hmma_fp32) + bool BMM2_FP16_EPILOGUE = true, + // The output type. + typename OutputType = typename Traits::A_type, + // The sage attention block size for Q, K and V + int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0> +using Kernel_traits_v2 = + Kernel_traits_::Gmem_tile_o, + S, D, DV, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD, FLAGS, 2, MASK_VERSION, + BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The sequence length. + int S, + // The hidden size per head. + int D, + // The hidden dimension of V. + int DV, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD, + // The flags. + uint32_t FLAGS = 0x8, + // The attention mask version (see src/mask.h). + int MASK_VERSION = 2, + // Do we use half epilogue for the 2nd GEMM (hmma_fp32) + bool BMM2_FP16_EPILOGUE = true, + // The output type. + typename OutputType = typename Traits::A_type, + // The sage attention block size for Q, K and V + int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0> +using Kernel_traits_v2_q_k_v = + Kernel_traits_::Gmem_tile_o, S, D, DV, STEP, WARPS_M, + WARPS_N, CTAS_PER_HEAD, FLAGS, 2, MASK_VERSION, BMM2_FP16_EPILOGUE, + SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The sequence length. + int S, + // The hidden size per head. + int D, + // The hidden dimension of V. + int DV, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. 
+ int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD, + // The flags. + uint32_t FLAGS = 0x8, + // The attention mask version (see src/mask.h). + int MASK_VERSION = 2, + // Do we use half epilogue for the 2nd GEMM (hmma_fp32) + bool BMM2_FP16_EPILOGUE = true, + // The output type. + typename OutputType = typename Traits::A_type, + // The sage attention block size for Q, K and V + int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0> +using Kernel_traits_v2_paged_kv_cache = + Kernel_traits_::Gmem_tile_o, S, D, DV, STEP, WARPS_M, + WARPS_N, CTAS_PER_HEAD, FLAGS, 2, MASK_VERSION, BMM2_FP16_EPILOGUE, + SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The sequence length. + int S, + // The hidden size per head. + int D, + // The hidden dimension of V. + int DV, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD, + // The flags. + uint32_t FLAGS = 0x8, + // The attention mask version (see src/mask.h). + int MASK_VERSION = 2, + // Do we use half epilogue for the 2nd GEMM (hmma_fp32) + bool BMM2_FP16_EPILOGUE = true, + // The output type. + typename OutputType = typename Traits::A_type, + // The sage attention block size for Q, K and V + int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0> +using Kernel_traits_v2_contiguous_kv_cache = + Kernel_traits_::Gmem_tile_o, S, D, 0, STEP, WARPS_M, + WARPS_N, CTAS_PER_HEAD, FLAGS, 2, MASK_VERSION, BMM2_FP16_EPILOGUE, + SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The sequence length for K and V. + int S_KV, + // The hidden size per head. + int D, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD, + // The flags. + uint32_t FLAGS = 0x8> +using Kernel_traits_fmhca = + Kernel_traits_fmhca_; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/mask.h b/csrc/fmha_v2/fmha/mask.h new file mode 100644 index 0000000000..3219947ccf --- /dev/null +++ b/csrc/fmha_v2/fmha/mask.h @@ -0,0 +1,785 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
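Editor's note: the convenience aliases above (Kernel_traits_v1/v2, the q_k_v, paged and contiguous KV-cache variants, and Kernel_traits_fmhca) are the intended way for a kernel to pick up a full configuration. A hypothetical instantiation is sketched below; only the parameter order and the member names come from this header, while the numeric choices, the include path, and the chosen instruction traits are illustrative.

```cpp
// Illustrative sketch only (not part of this diff): instantiating the v2 kernel
// traits for a causal FP16 kernel. All numeric parameters are hypothetical.
#include <fmha/kernel_traits.h>  // assumed include path for this header

// Traits, S, D, DV, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD, FLAGS, MASK_VERSION.
using Kt = fmha::Kernel_traits_v2<fmha::Ampere_hmma_fp16_traits,
                                  /*S=*/512, /*D=*/64, /*DV=*/64, /*STEP=*/64,
                                  /*WARPS_M=*/4, /*WARPS_N=*/1, /*CTAS_PER_HEAD=*/1,
                                  /*FLAGS=*/0x8u, /*MASK_VERSION=*/3>;

static_assert(Kt::CAUSAL_MASK, "MASK_VERSION 3 selects the causal mask");
static_assert(!Kt::SLIDING_WINDOW_ATTENTION, "MASK_VERSION 4 would enable sliding window");
static_assert(Kt::SHARE_SMEM_FOR_K_AND_V, "FLAGS bit 0x8 shares the K/V smem buffer");
```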
+ */ + +#pragma once + +#include "fmha/traits.h" + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask { + // The shape of the MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in each dimension. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // Ctor. + template + inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx) { + // The pointer. + packed_mask_ptr_ = reinterpret_cast(params.packed_mask_ptr); + // Take the head into account. + packed_mask_ptr_ += block_info.bidb * params.packed_mask_stride_in_bytes; + // The thread inside the CTA. + packed_mask_ptr_ += tidx * sizeof(uint32_t); + } + + // Load the mask into registers (and expand). + inline __device__ void load(int it) { + // One 32-bit integer per MMA. + uint32_t packed_mask[MMAS_M]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + int offset = (it * MMAS_M + mi) * Cta_tile::THREADS_PER_CTA * sizeof(uint32_t); + fmha::ldg(packed_mask[mi], packed_mask_ptr_ + offset); + } + +// Expand the mask. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + mask_[2 * mi + 0][4 * ni + 0] = packed_mask[mi] & (1u << (8 * ni + 0)); + mask_[2 * mi + 0][4 * ni + 1] = packed_mask[mi] & (1u << (8 * ni + 1)); + mask_[2 * mi + 1][4 * ni + 0] = packed_mask[mi] & (1u << (8 * ni + 2)); + mask_[2 * mi + 1][4 * ni + 1] = packed_mask[mi] & (1u << (8 * ni + 3)); + mask_[2 * mi + 0][4 * ni + 2] = packed_mask[mi] & (1u << (8 * ni + 4)); + mask_[2 * mi + 0][4 * ni + 3] = packed_mask[mi] & (1u << (8 * ni + 5)); + mask_[2 * mi + 1][4 * ni + 2] = packed_mask[mi] & (1u << (8 * ni + 6)); + mask_[2 * mi + 1][4 * ni + 3] = packed_mask[mi] & (1u << (8 * ni + 7)); + } + } + } + + // Is a given position valid? + inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const { + return mask_[mi * 2 + ii][ni * 4 + jj]; + } + + // The pointer to the mask. + char const* packed_mask_ptr_; + // The mask after expansion. + bool mask_[MMAS_M * 2][MMAS_N * 4]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask { + // The instruction traits. + using Traits = Volta_hmma_fp16_traits; + // The shape of the MMA tile. + using Mma_tile = typename Traits::Mma_tile; + + // The number of MMAs in each dimension. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // Ctor. + template + inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx) { + // The pointer. + packed_mask_ptr_ = reinterpret_cast(params.packed_mask_ptr); + // Take the head into account. + packed_mask_ptr_ += block_info.bidb * params.packed_mask_stride_in_bytes; + // The thread inside the CTA. + packed_mask_ptr_ += tidx * sizeof(uint32_t); + } + + // Load the mask into registers (and expand). + inline __device__ void load(int it) { + // One 32-bit integer per MMA. + uint32_t packed_mask[MMAS_M]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + int offset = (it * MMAS_M + mi) * Cta_tile::THREADS_PER_CTA * sizeof(uint32_t); + fmha::ldg(packed_mask[mi], packed_mask_ptr_ + offset); + } + +// Expand the mask. 
+#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ii = 0; ii < MMAS_N * 8; ++ii) { + mask_[mi][ii] = packed_mask[mi] & (1u << ii); + } + } + } + + // Is a given position valid? + inline __device__ bool is_valid(int mi, int ni, int, int jj) const { + return mask_[mi][ni * 8 + jj]; + } + + // The pointer to the mask. + char const* packed_mask_ptr_; + // The mask after expansion. + bool mask_[MMAS_M][MMAS_N * 8]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask { + // That implementation works only when WARPS_K is 1. + static_assert(Cta_tile::WARPS_K == 1, ""); + + // The shape of the MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // Ctor. + template + inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx) + : seqlen_(block_info.actual_seqlen), col_loop_step_(0) { + // The decomposition of the thread index into warp/lane. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // The position of the warp. + int warp_n = warp / Cta_tile::WARPS_M; + // The position of the thread. + col_ = block_info.bidn * Cta_tile::N + warp_n * 16 + lane % 4 * 2; + col_init_ = col_; + } + + // Is a given position valid? + inline __device__ bool is_valid(int, int ni, int, int jj) const { + // The position of the thread in the sequence. + int offset = this->col_ + this->col_loop_step_ * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA; + // The position inside the MMA. + offset += (jj & 0x02) * 4 + (jj & 0x1); + // Is it a valid position in the sequence? + return offset < seqlen_; + } + + // BERT Mask: if upper left is invalid, none are valid + inline __device__ bool any_valid(int mi, int ni) const { return is_valid(mi, ni, 0, 0); } + + // Move mask to next tile (flash attention) + inline __device__ void move() { this->col_ += Cta_tile::N; } + + // Move mask the col by offset (flash attention) + inline __device__ void move_to_offset(int offset) { this->col_ = col_init_ + offset; } + + // Reset mask to the initial col + inline __device__ void reset() { col_ = col_init_; } + + // Load the mask... Nothing to do for real. + inline __device__ void load(int) {} + + // Load the mask... we use it to keep track of to row, col (flash attention). + inline __device__ void load(int, int col_loop_step) { col_loop_step_ = col_loop_step; } + + // The length of the sequence. + int seqlen_; + // The left-most position of the thread in the sequence. + int col_, col_init_; + // The current col iteration + int col_loop_step_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask : public Mask { + // V3 mask is the causal mask (e.g. for GPT) and extends V2 masks (self-attention). + using Base = Mask; + + // The shape of the MMA tile. + using Mma_tile = typename Base::Mma_tile; + + // Ctor. + template + inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx) + : Base(params, block_info, tidx), row_loop_step_(0) { + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // The position of the warp. + int warp_m = warp % Cta_tile::WARPS_M; + row_ = warp_m * 16 + lane / 4; + } + + inline __device__ void get_row_col(int& row, int& col, int mi, int ni, int ii, int jj) const { + // The position of the thread in the sequence. 
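Editor's note: the version-2 (padding) mask above never touches memory; each thread simply checks whether the score-matrix columns it owns in the MMA layout fall inside the actual sequence length. The standalone reproduction of that column arithmetic below (non-Volta layout, flash-attention col_loop_step/move() offsets omitted) is not part of the diff, and CTA_N / N_PER_MMA_PER_CTA are hypothetical stand-ins for the Cta_tile/Mma_tile constants.

```cpp
// Standalone sketch (not part of this diff): the column positions tested by the
// version-2 padding mask above, for one thread in the non-Volta MMA layout.
#include <cstdio>

int main() {
  int const CTA_N = 128, N_PER_MMA_PER_CTA = 16, THREADS_PER_WARP = 32, WARPS_M = 4;
  int const seqlen = 70;  // actual (unpadded) sequence length
  int const tidx = 5, bidn = 0;

  int const warp = tidx / THREADS_PER_WARP;
  int const lane = tidx % THREADS_PER_WARP;
  int const warp_n = warp / WARPS_M;  // horizontal warp index
  int const col_init = bidn * CTA_N + warp_n * 16 + lane % 4 * 2;

  // Each MMA "ni" spans N_PER_MMA_PER_CTA columns; each thread owns 4 positions (jj) per MMA.
  for (int ni = 0; ni < 2; ++ni) {
    for (int jj = 0; jj < 4; ++jj) {
      int const col = col_init + ni * N_PER_MMA_PER_CTA + (jj & 0x02) * 4 + (jj & 0x1);
      std::printf("tidx=%d ni=%d jj=%d -> col=%d valid=%d\n", tidx, ni, jj, col, col < seqlen);
    }
  }
  return 0;
}
```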
+ row = this->row_ + this->row_loop_step_ + mi * Mma_tile::M_PER_MMA_PER_CTA; + // The position inside the MMA. + row += ii * 8; + + // The position of the thread in the sequence. + col = this->col_ + this->col_loop_step_ * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA; + // The position inside the MMA. + col += (jj & 0x02) * 4 + (jj & 0x1); + } + + // Is a given position valid? + inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const { + int row, col; + get_row_col(row, col, mi, ni, ii, jj); + + // Is it a valid position in the sequence? + return is_valid(row, col); + } + + // Is a given position valid? + inline __device__ bool is_valid(int row, int col) const { + // Is it a valid position in the sequence, i.e. are we in the lower triangle? + return (row >= col); + } + + // GPT Mask: if lower left is invalid, none are valid + inline __device__ bool any_valid(int mi, int ni) const { return is_valid(mi, ni, 1, 0); } + + // Load the mask... we use it to keep track of to row. + inline __device__ void load(int row_loop_step) { row_loop_step_ = row_loop_step; } + + // Load the mask... we use it to keep track of to row, col (flash attention). + inline __device__ void load(int row_loop_step, int col_loop_step) { + row_loop_step_ = row_loop_step; + this->col_loop_step_ = col_loop_step; + } + + // The upper-most position of the thread in the sequence. + int row_; + // Current row step offset. + int row_loop_step_; +}; + +// Specialized mask for MTP (multi-token prediction used in MLA). +template +struct MtpMask : public Mask { + // MTP mask (causal mask) extends from V2 (dense) masks (self-attention). + using Base = Mask; + + // The shape of the MMA tile. + using Mma_tile = typename Base::Mma_tile; + + // Ctor. + template + inline __device__ MtpMask(Params const& params, Block_info const& block_info, int tidx) + : Base(params, block_info, tidx), + num_grouped_heads_(params.num_grouped_heads), + row_loop_step_(0) { + // Update the seqlen (excluding all MTP draft tokens). + this->seqlen_ = this->seqlen_ - (block_info.actual_q_seqlen / params.num_grouped_heads) + 1; + + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // The position of the warp. + int warp_m = warp % Cta_tile::WARPS_M; + row_ = warp_m * 16 + lane / 4; + } + + inline __device__ int get_row(int mi, int ii) const { + // The position of the thread in the sequence. + int row = this->row_ + this->row_loop_step_ + mi * Mma_tile::M_PER_MMA_PER_CTA; + // The position inside the MMA. + row += ii * 8; + return row; + } + + inline __device__ int get_col(int ni, int jj) const { + // The position of the thread in the sequence. + int col = this->col_ + this->col_loop_step_ * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA; + // The position inside the MMA. + col += (jj & 0x02) * 4 + (jj & 0x1); + return col; + } + + inline __device__ void get_row_col(int& row, int& col, int mi, int ni, int ii, int jj) const { + row = get_row(mi, ii); + col = get_col(ni, jj); + } + + // Is a given position valid? + inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const { + int col = get_col(ni, jj); + + // Is it a valid position in the sequence? + return col < (this->seqlen_ + mtp_token_idx_[mi][ii]); + } + + // Is a given position valid? + inline __device__ bool is_valid(int row, int col) const { + // Is it a valid position in the sequence, i.e. are we in the lower triangle? + return (row >= col); + } + + // Load the mask... we use it to keep track of to row. 
+ inline __device__ void load(int row_loop_step) { + row_loop_step_ = row_loop_step; +// Update the MTP token index. +#pragma unroll + for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi) { +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + mtp_token_idx_[mi][ii] = get_row(mi, ii) / num_grouped_heads_; + } + } + } + + // The number of grouped heads in the row dimension. + int num_grouped_heads_; + // The corresponding MTP token index for each row. + // FIXME: currently we assume 2 rows per thread (volta/hopper-gmma traits are not supported yet). + int mtp_token_idx_[Mma_tile::MMAS_M][2]; + // The upper-most position of the thread in the sequence. + int row_; + // The current row step offset. + int row_loop_step_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// The lower triangle attention matrix. +// Assume we only pay attention to past sliding-window-size long sequence. +// v x x x x x x x x +// v v x x x x x x x +// v v v x x x x x x +// v v v v x x x x x +// v v v v v x x x x +// x v v v v v x x x +// x x v v v v v x x +// x x x v v v v v x +// x x x x v v v v v + +template +struct Mask : public Mask { + // V4 mask is the causal mask (e.g. for GPT) plus the sliding-window feature. + using Base = Mask; + + // The shape of the MMA tile. + using Mma_tile = typename Base::Mma_tile; + + // Ctor. + template + inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx) + : Base(params, block_info, tidx), sliding_window_size_(params.sliding_window_size) {} + + // Is a given position valid? + inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const { + int row, col; + this->get_row_col(row, col, mi, ni, ii, jj); + + // Is it a valid position in the sequence? + return is_valid(row, col); + } + + // Is a given position valid? + inline __device__ bool is_valid(int row, int col) const { + // Is it a valid position in the sequence, i.e. are we in the lower triangle? + return (row >= col) && (col >= max(0, row + 1 - sliding_window_size_)); + } + + // The sliding window size. + int sliding_window_size_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// The custom mask (from global memory). +template +struct Mask : public Mask { + using Base = Mask; + + // The shape of the MMA tile. + using Mma_tile = typename Base::Mma_tile; + + // The number of MMAs in each dimension. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // One 32-bit packed mask holds 4 MMAS_N as one group. + enum { MMA_GROUPS_N = fmha::Div_up::VALUE }; + + // The MMAS_N in the group. + enum { MMAS_N_IN_GROUP = fmha::Min::VALUE }; + + // MMAS_N uses full 32-bit integer packed masks. + enum { FULL_PACKED_MASK = (MMAS_N % 4 == 0) }; + + // Ctor. + template + inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx) + : Base(params, block_info, tidx), + packed_mask_ptr_(reinterpret_cast(params.packed_mask_ptr)), + params_packed_mask_stride_in_bytes_(params.packed_mask_stride_in_bytes), + row_offset_(0) { + // Add the thread offset in bytes. + packed_mask_ptr_ += + (block_info.sum_mask_row * params_packed_mask_stride_in_bytes_ + tidx * sizeof(uint32_t)); + } + + // Load the mask... we use it to keep track of row offset. + inline __device__ void load(int row_offset) { row_offset_ = row_offset; } + + // Load the mask into registers (and expand). 
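Editor's note: the version-3 and version-4 masks above reduce to pure (row, col) predicates: causal keeps the lower triangle, and the sliding-window variant additionally drops columns older than sliding_window_size. The standalone check below reproduces the pattern in the comment diagram above (9x9 with a window of 5); it is not part of the diff.

```cpp
// Standalone sketch (not part of this diff): the causal (v3) and sliding-window
// causal (v4) predicates used by the masks above.
#include <algorithm>
#include <cstdio>

bool causal_valid(int row, int col) { return row >= col; }

bool sliding_window_valid(int row, int col, int window) {
  return row >= col && col >= std::max(0, row + 1 - window);
}

int main() {
  int const size = 9, window = 5;  // matches the 9x9 diagram with a window of 5
  for (int row = 0; row < size; ++row) {
    for (int col = 0; col < size; ++col) {
      std::printf("%c ", sliding_window_valid(row, col, window) ? 'v' : 'x');
    }
    std::printf("\n");
  }
  return 0;
}
```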
+ inline __device__ void load_mask(int col_offset) { + // The packed_mask_offset in the col(N) dimension. + int mask_col_offset = int(col_offset / (Mma_tile::N_PER_MMA_PER_CTA * 4)) * + Cta_tile::THREADS_PER_CTA * sizeof(uint32_t); + // When MMAS_N < 4, one loaded packed_mask can be expanded to boolean masks + // of multiple iterations. + int local_col = FULL_PACKED_MASK ? 0 : (col_offset % (Mma_tile::N_PER_MMA_PER_CTA * 4)); + // The local mma ni if MMAS_N < 4. + int local_ni = local_col / 16; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // The M dimension offset. + int offset = + (row_offset_ + mi * Mma_tile::M_PER_MMA_PER_CTA) * params_packed_mask_stride_in_bytes_; + // The N dimension offset. + offset += mask_col_offset; + // Set predicate to true only when next 32-bit packed mask is needed. + bool pred = local_col == 0; +#pragma unroll + for (int ni = 0; ni < MMA_GROUPS_N; ++ni) { + // The MMAS_N group offset. + if (pred) { + fmha::ldg(packed_mask_[mi][ni], + packed_mask_ptr_ + offset + ni * Cta_tile::THREADS_PER_CTA * sizeof(uint32_t)); + } + } + } + +// Expand the mask. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMA_GROUPS_N; ++ni) { +#pragma unroll + for (int nni = 0; nni < MMAS_N_IN_GROUP; ++nni) { + mask_[2 * mi + 0][(ni * 4 + nni) * 4 + 0] = + packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 0)); + mask_[2 * mi + 0][(ni * 4 + nni) * 4 + 1] = + packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 1)); + mask_[2 * mi + 1][(ni * 4 + nni) * 4 + 0] = + packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 2)); + mask_[2 * mi + 1][(ni * 4 + nni) * 4 + 1] = + packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 3)); + mask_[2 * mi + 0][(ni * 4 + nni) * 4 + 2] = + packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 4)); + mask_[2 * mi + 0][(ni * 4 + nni) * 4 + 3] = + packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 5)); + mask_[2 * mi + 1][(ni * 4 + nni) * 4 + 2] = + packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 6)); + mask_[2 * mi + 1][(ni * 4 + nni) * 4 + 3] = + packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 7)); + } + } + } + } + + // Move mask the col by offset (flash attention) + inline __device__ void move_to_offset(int col_offset) { load_mask(col_offset); } + + // Is a given position valid? + inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const { + return mask_[mi * 2 + ii][ni * 4 + jj]; + } + + // Current row step offset. + int row_offset_; + + // The pointer to the mask. + char const* packed_mask_ptr_; + // The stride in the n dimension. + int64_t const params_packed_mask_stride_in_bytes_; + // The packed mask (one 32-bit integer per MMA GROUP, MMAS_M * 2 rows, MMA_GROUPS_N * 16 cols). + uint32_t packed_mask_[MMAS_M][MMA_GROUPS_N]; + // The mask after expansion. + bool mask_[MMAS_M * 2][MMAS_N * 4]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask { + // The instruction traits. + using Traits = Volta_hmma_fp16_traits; + // The shape of the MMA tile. + using Mma_tile = typename Traits::Mma_tile; + + // That implementation works only when WARPS_K is 1. + static_assert(Cta_tile::WARPS_K == 1, ""); + + // Ctor. + template + inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx) + : seqlen_(block_info.actual_seqlen) { + // The decomposition of the thread index into warp/lane. 
+ int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // The position of the warp. + int warp_n = warp / Cta_tile::WARPS_M; + // The position of the thread. + col_ = block_info.bidn * Cta_tile::N + warp_n * 16 + (lane & 0x08) / 2; + col_init_ = col_; + } + + // Is a given position valid? + inline __device__ bool is_valid(int, int ni, int, int jj) const { + // The position of the thread in the sequence. + int offset = this->col_ + ni * Mma_tile::N_PER_MMA_PER_CTA; + // The position inside the MMA. + offset += (jj & 0x04) * 2 + (jj & 0x03); + // Is it a valid position in the sequence? + return offset < seqlen_; + } + + // Load the mask... Nothing to do for real. + inline __device__ void load(int) {} + + // Reset mask to the initial col + inline __device__ void reset() { col_ = col_init_; } + + // Move mask to next tile (flash attention) + inline __device__ void move() { this->col_ += Cta_tile::N; } + + // Move mask the col by offset (flash attention) + inline __device__ void move_to_offset(int offset) { this->col_ = col_init_ + offset; } + + // The length of the sequence. + int const seqlen_; + // The left-most position of the thread in the sequence. + int col_, col_init_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask + : public Mask { + // V3 mask is the causal mask (e.g. for GPT) and extends V2 masks (self-attention). + using Base = Mask; + + // The shape of the MMA tile. + using Mma_tile = typename Base::Mma_tile; + + // Ctor. + template + inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx) + : Base(params, block_info, tidx), loop_step_(0) { + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // The position of the warp. + int warp_m = warp % Cta_tile::WARPS_M; + row_ = warp_m * 16 + (lane & 0x07) + (lane & 0x10) / 2; + } + + inline __device__ void get_row_col(int& row, int& col, int mi, int ni, int ii, int jj) const { + // The position of the thread in the sequence. + row = this->row_ + this->loop_step_ + mi * Mma_tile::M_PER_MMA_PER_CTA; + + // The position of the thread in the sequence. + col = this->col_ + ni * Mma_tile::N_PER_MMA_PER_CTA; + // The position inside the MMA. + col += (jj & 0x04) * 2 + (jj & 0x03); + } + + // Is a given position valid? + inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const { + int row, col; + get_row_col(row, col, mi, ni, ii, jj); + + // Is it a valid position in the sequence? + return is_valid(row, col); + } + + // Is a given position valid? + inline __device__ bool is_valid(int row, int col) const { + // Is it a valid position in the sequence, i.e. are we in the lower triangle? + return (row >= col) && (col < this->seqlen_); + } + + // GPT Mask: if lower left is invalid, none are valid + inline __device__ bool any_valid(int mi, int ni) const { return is_valid(mi, ni, 0, 0); } + + // Load the mask... we use it to keep track of to row. + inline __device__ void load(int loop_step) { loop_step_ = loop_step; } + + // The upper-most position of the thread in the sequence. + int row_; + // Current iteration. 
+ int loop_step_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask_dispatcher {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask_dispatcher + : public Mask { + using Base = Mask; + + template + inline __device__ Mask_dispatcher(Params const& params, Block_info const& block_info, int tidx) + : Base(params, block_info, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask_dispatcher : public MtpMask { + using Base = MtpMask; + + template + inline __device__ Mask_dispatcher(Params const& params, Block_info const& block_info, int tidx) + : Base(params, block_info, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask_hopper { + // The shape of the MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // Ctor. + template + inline __device__ Mask_hopper(Params const& params, Block_info const& block_info, int tidx) + : seqlen_(block_info.actual_seqlen) { + // For Hopper the warp distribution is always 4x1 within a warpgroup. + // So maybe there is some assumptions/optimizations to be made here. + + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int warp_n = warp / 4; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + col_ = warp_n * Mma_tile::N_PER_WARP_GROUP + (lane % 4) * 2; + } + + // Is a given position valid? + inline __device__ bool is_valid(int, int ni, int, int jj) const { + // The position of the thread in the sequence. + int offset = this->col_ + ni * Mma_tile::N_PER_MMA; + // The position inside the MMA. + offset += (jj / 2) * 8 + (jj % 2); + // Is it a valid position in the sequence? + return offset < seqlen_; + } + + // Load the mask... Nothing to do for real. + inline __device__ void load(int) {} + + // The length of the sequence. + int const seqlen_; + // The left-most position of the thread in the sequence. + int col_; +}; + +template +struct Mask_hopper { + // The shape of the MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // Ctor. + template + inline __device__ Mask_hopper(Params const& params, Block_info const& block_info, int tidx) { + // For Hopper the warp distribution is always 4x1 within a warpgroup. + // So maybe there is some assumptions/optimizations to be made here. + + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int warp_n = warp / 4; + int warp_m = warp % 4; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + col_ = warp_n * Mma_tile::N_PER_WARP_GROUP + (lane % 4) * 2; + row_base_ = warp_m * 16 + lane / 4; + row_ = row_base_; + } + + inline __device__ void get_row_col(int& row, int& col, int mi, int ni, int ii, int jj) const { + // The row position of the thread in the sequence. + row = row_ + mi * Mma_tile::M_PER_MMA + ii * 8; + + // The position of the thread in the sequence. + col = this->col_ + ni * Mma_tile::N_PER_MMA; + // The position inside the MMA. + col += (jj / 2) * 8 + (jj % 2); + } + + // Is a given position valid? + inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const { + int row, col; + get_row_col(row, col, mi, ni, ii, jj); + + // Is it a valid position in the sequence? + return is_valid(row, col); + } + + // Is a given position valid? + inline __device__ bool is_valid(int row, int col) const { + // Is it a valid position in the sequence? 
+ return col <= row; + } + + // Load the mask... Nothing to do for real. + inline __device__ void load(int loop_step) { row_ = row_base_ + loop_step * Cta_tile::M; } + + // The left-most position of the thread in the sequence. + int row_, row_base_, col_; +}; + +template +struct Mask_hopper : public Mask_hopper { + // V4 mask is the causal mask (e.g. for GPT) plus the sliding-window feature. + using Base = Mask_hopper; + + // The shape of the MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // Ctor. + template + inline __device__ Mask_hopper(Params const& params, Block_info const& block_info, int tidx) + : Base(params, block_info, tidx), sliding_window_size_(params.sliding_window_size) {} + + // Is a given position valid? + inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const { + int row, col; + this->get_row_col(row, col, mi, ni, ii, jj); + + // Is it a valid position in the sequence? + return is_valid(row, col); + } + + // Is a given position valid? + inline __device__ bool is_valid(int row, int col) const { + // Is it a valid position in the sequence? + return col <= row && col >= max(0, row + 1 - sliding_window_size_); + } + + // The sliding window size for attention. + int sliding_window_size_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/numeric_types.h b/csrc/fmha_v2/fmha/numeric_types.h new file mode 100644 index 0000000000..1c3ec1a615 --- /dev/null +++ b/csrc/fmha_v2/fmha/numeric_types.h @@ -0,0 +1,57 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include + +#include + +#pragma once + +#if CUDART_VERSION >= 11080 +// TODO Better way? 
+#define FMHA_CUDA_SUPPORTS_FP8 true +#endif +#include +#if FMHA_CUDA_SUPPORTS_FP8 +#include +#endif +namespace fmha { + +using fp16_t = uint16_t; +using fp32_t = float; +using tf32_t = uint32_t; +using bf16_t = nv_bfloat16; +#if FMHA_CUDA_SUPPORTS_FP8 +using e4m3_t = __nv_fp8_e4m3; +using e5m2_t = __nv_fp8_e5m2; +#else +using e4m3_t = char; +using e5m2_t = char; +#endif + +static constexpr float MAX_E4M3 = 448.f; // 0x7E 2^8 * 1.75 +static constexpr float MAX_E5M2 = 57344.f; // 0x7B 2^15 * 1.75 + +template +__host__ __device__ constexpr inline float Softmax_fp_quant_scale(); + +template <> +__host__ __device__ constexpr inline float Softmax_fp_quant_scale() { + // Softmax has max output of 1.0, therefore we choose fp32-to-fp8 quantization scale as the + // largest power-of-2 below the e4m3 limit: + // 2^(floor(log2(E4M3_MAX / amax_exp_p))) = 2^(floor(log2(448 / 1))) = 2 ^ 8 + return 256.f; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/paged_kv_cache.h b/csrc/fmha_v2/fmha/paged_kv_cache.h new file mode 100644 index 0000000000..a8e13a61d0 --- /dev/null +++ b/csrc/fmha_v2/fmha/paged_kv_cache.h @@ -0,0 +1,63 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include + +namespace fmha { + +// This needs to be aligned with the definition in TRT-LLM +struct Kv_block_array { + using PtrType = int32_t; + + // Maximum number of sequences supported by the kv-cache. + int32_t mMaxSeqs; + // Max number of blocks per sequence + int32_t mMaxBlocksPerSeq; + // Number of tokens. It must be power of 2. + int32_t mTokensPerBlock; + // Exponent of number of tokens with base 2. + // E.g. for mTokensPerBlock 64, mTokensPerBlockLog2 equals to 6 + int32_t mTokensPerBlockLog2; + // Table maps logical block idx to the data pointer of k/v cache block pool + // Shape [B, W, 2, M], where 2 is table for K and V, + // B is current number of sequences + // W is beam width + // M is Max number of blocks per sequence + + // Size of KV cache blocks in bytes (H*D*T*sizeof(DataType)) + int32_t mBytesPerBlock; + // Pointer to beginning of pool. + void* mPoolPtr; + // Pointer to block offsets. 
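Editor's note: Softmax_fp_quant_scale above returns 256.f because softmax outputs are bounded by 1.0 and 2^floor(log2(448 / 1)) = 2^8 is the largest power of two that keeps a scaled probability below the e4m3 maximum. The standalone check below only re-derives that number; the helper name and signature are local to this sketch, not part of the diff.

```cpp
// Standalone sketch (not part of this diff): the derivation behind the 256.f
// softmax quantization scale above -- the largest power of two that keeps a
// probability in [0, 1] below the e4m3 maximum of 448.
#include <cmath>
#include <cstdio>

float softmax_fp_quant_scale(float fp_max, float amax /* max softmax output */) {
  return std::exp2(std::floor(std::log2(fp_max / amax)));
}

int main() {
  float const MAX_E4M3 = 448.f;
  float const scale = softmax_fp_quant_scale(MAX_E4M3, 1.f);
  std::printf("scale = %.0f\n", scale);                             // 256
  std::printf("1.0 * scale = %.0f (< %.0f)\n", scale, MAX_E4M3);    // stays representable in e4m3
  return 0;
}
```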
+ PtrType* mBlockOffsets;
+
+ Kv_block_array() = default;
+
+ Kv_block_array(int32_t batchSize, int32_t maxBlocksPerSeq, int32_t tokensPerBlock,
+ int32_t bytesPerBlock, void* poolPtr)
+ : mMaxSeqs(batchSize),
+ mMaxBlocksPerSeq(maxBlocksPerSeq),
+ mTokensPerBlock(tokensPerBlock),
+ mBytesPerBlock{bytesPerBlock},
+ mPoolPtr{poolPtr},
+ mBlockOffsets{nullptr} {
+ float const tokensPerBlockSeqLog2 = log2(mTokensPerBlock);
+ mTokensPerBlockLog2 = static_cast<int32_t>(tokensPerBlockSeqLog2);
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace fmha
diff --git a/csrc/fmha_v2/fmha/smem_tile.h b/csrc/fmha_v2/fmha/smem_tile.h
new file mode 100644
index 0000000000..dd75cf7bdb
--- /dev/null
+++ b/csrc/fmha_v2/fmha/smem_tile.h
@@ -0,0 +1,2071 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights
+ * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#pragma once
+
+#include 
+#include 
+#include 
+
+namespace fmha {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+ // The description of the tile computed by this CTA.
+ typename Cta_tile,
+ // The number of rows in the 2D shared memory buffer.
+ int M_,
+ // The number of cols.
+ int N_,
+ // The size in bits of each element.
+ int BITS_PER_ELEMENT_,
+ // The number of bytes per STS.
+ int BYTES_PER_STS_ = 16,
+ // The number of buffers. (Used in multistage and double buffer cases.)
+ int BUFFERS_PER_TILE_ = 1,
+ // Do we enable the fast path for LDS.128 and friends.
+ int ENABLE_LDS_FAST_PATH_ = 0,
+ // The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
+ int ROWS_PER_XOR_PATTERN_ = 8,
+ // The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
+ int COLS_PER_XOR_PATTERN_ = 1,
+ // Use or not predicates
+ bool USE_PREDICATES_ = true,
+ // Use TMA or not,
+ bool USE_TMA_ = false,
+ // The leading dim elements in shared memory
+ int LEAD_DIM_ELEMENTS_ = N_>
+struct Smem_tile_without_skews {
+ // The type of this tile
+ using Smem_tile_ =
+ Smem_tile_without_skews<Cta_tile, M_, N_, BITS_PER_ELEMENT_, BYTES_PER_STS_, BUFFERS_PER_TILE_, ENABLE_LDS_FAST_PATH_, ROWS_PER_XOR_PATTERN_, COLS_PER_XOR_PATTERN_, USE_PREDICATES_, USE_TMA_, LEAD_DIM_ELEMENTS_>;
+
+ static constexpr bool USE_TMA = USE_TMA_;
+
+ // The size in bits of each element.
+ enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ };
+
+ // The size in bytes of a single STS.
+ enum { BYTES_PER_STS = BYTES_PER_STS_ };
+
+ // The number of elements per STS.
+ enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT };
+
+ // To support arbitrary N, we pad some values to a power-of-2.
+ enum { N_WITH_PADDING = Next_power_of_two<N_>::VALUE };
+
+ // The number of bytes per row without packing of rows.
+ enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 };
+
+ // The number of bytes per row -- we want at least 128B per row.
+ enum { BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE };
+
+ // The number of rows in shared memory (two rows may be packed into a single one).
+ enum { ROWS = M_ * N_ / LEAD_DIM_ELEMENTS_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW };
+
+ // The number of threads per row.
+ enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS }; + + // The number of threads per row. + enum { THREADS_PER_ROW = Min::VALUE }; + + // The number of STS per row. + enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS }; + + // It must be at least one. + static_assert(STS_PER_ROW >= 1, ""); + + // The number of rows written with a single STS. + enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // Make sure we write to at least one row per STS. Thanks Dr. Obvious ;) + static_assert(ROWS_PER_STS >= 1, ""); + + // The number of STS needed to store all rows. + enum { STS_PER_COL = Div_up::VALUE }; + + // The number of STS in total. + enum { STS = STS_PER_COL * STS_PER_ROW }; + + // The size of one buffer in bytes in shared memory. + enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA }; + + // The number of buffers. + enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; + + // The size in bytes of total buffers. + enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; + + // The boundary for smem_read_offset and smem_write_offset increment. + enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; + + // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. + enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ }; + + // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. + enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS }; + + // Use or not predicates + enum { USE_PREDICATES = USE_PREDICATES_ }; + + // The bytes of one shmem row + enum { BYTES_PER_SHMEM_ROW = 128 }; + + // The type of elements that are stored in shared memory by each thread. + using Store_type = typename Uint_from_size_in_bytes::Type; + + // Ctor. + inline __device__ Smem_tile_without_skews(void* smem, int tidx) + : smem_(__nvvm_get_smem_pointer(smem)) { + // The row written by a thread. See doc/mma_smem_layout.xlsx. + int smem_write_row = tidx / THREADS_PER_ROW; + + // The XOR pattern. + int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN; + // Compute the column and apply the XOR pattern. + int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor; + + // The offset. + this->smem_write_offset_ = smem_write_row * BYTES_PER_ROW + smem_write_col * BYTES_PER_STS; + + // That code is expected to trigger the utilization of the URF by the compiler. + this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0); + this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0); + } + + // Compute the store pointers. + template + inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) { +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + // Decompose the STS into row/col. + int row = ii % STS_PER_COL; + int col = ii / STS_PER_COL; + + // Compute the immediate. + int imm = row; + + // Assemble the offset. + int offset = smem_write_offset_ + imm * ROWS_PER_STS * BYTES_PER_ROW; + + // Take the column into account. + if (STS_PER_ROW > 1) { + offset += col * THREADS_PER_ROW * BYTES_PER_STS; + } + + // Apply the XOR pattern if needed. 
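+ // (Illustrative note: with, e.g., ROWS_PER_XOR_PATTERN = 8, COLS_PER_XOR_PATTERN = 1 and 16B
+ // stores, row r has its 16B slot index XORed with r % 8, so a column of eight consecutive
+ // rows maps to eight distinct slots and avoids shared memory bank conflicts.)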
+ if (ROWS_PER_STS < ROWS_PER_XOR_PATTERN) { + int const m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN; + offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS; + } + +// Assemble the final pointer :) +#pragma unroll + for (int k = 0; k < K; k++) { + ptrs[ii * K + k] = smem_ + offset + k * (BYTES_PER_STS / K) + smem_write_buffer_; + } + } + } + + inline __device__ void debug_reset() { + for (int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) { + for (int row = 0; row < ROWS; ++row) { + for (int col = 0; col < BYTES_PER_ROW; col += 4) { + if (threadIdx.x == 0) { + uint32_t val = 0x0; + sts(val, smem_ + row * BYTES_PER_ROW + col + buffer); + } + } + } + } + } + + // Print the content of the tile (only for debug ;)). + inline __device__ void debug_print() const { + for (int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) { + for (int row = 0; row < ROWS; ++row) { + for (int col = 0; col < BYTES_PER_ROW; col += 4) { + if (threadIdx.x == 0) { + uint32_t val; + lds(val, smem_ + row * BYTES_PER_ROW + col + buffer); + printf( + "block=(x=%2d, y=%2d, z=%2d) (smem_=0x%08x, buffer=%2d, row=%2d, " + "byte=%4d)=0x%08x\n", + blockIdx.x, blockIdx.y, blockIdx.z, smem_, buffer, row, col, val); + } + } + } + } + } + + // Move the read offset to next buffer. + inline __device__ void move_to_next_read_buffer() { + if (BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY) { + this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY; + } else if (BUFFERS_PER_TILE > 1) { + this->smem_read_buffer_ += BYTES_PER_BUFFER; + } + } + + // Move the read offset to next buffer. TODO: Remove this member function!!! + inline __device__ void move_next_read_buffer() { this->move_to_next_read_buffer(); } + + // Move the read offset to next N buffer (circular-buffer). + inline __device__ void move_to_next_read_buffer(int N) { + if (BUFFERS_PER_TILE > 1) { + this->smem_read_buffer_ += N * BYTES_PER_BUFFER; + this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0; + } + } + + // Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!! + inline __device__ void move_next_read_buffer(int N) { this->move_to_next_read_buffer(N); } + + // Move the write offset to next buffer. + inline __device__ void move_to_next_write_buffer() { + if (BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY) { + this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY; + } else if (BUFFERS_PER_TILE > 1) { + this->smem_write_buffer_ += BYTES_PER_BUFFER; + } + } + + // Move the write offset to next buffer. TODO: Remove that member function! + inline __device__ void move_next_write_buffer() { this->move_to_next_write_buffer(); } + + // Move the read offset. + inline __device__ void move_read_offset(int delta) { this->smem_read_offset_ += delta; } + + // Move the write offset. + inline __device__ void move_write_offset(int delta) { this->smem_write_offset_ += delta; } + + // Store to the tile in shared memory. + template + inline __device__ void store(Store_type const (&data)[N]) { + uint32_t smem_ptrs[N]; + this->compute_store_pointers(smem_ptrs); + sts(smem_ptrs, data); + } + + // Store to the tile in shared memory. + template + inline __device__ void store(Store_type const (&data)[N], uint32_t (&preds)[M]) { + uint32_t smem_ptrs[N]; + this->compute_store_pointers(smem_ptrs); + sts(smem_ptrs, data, preds); + } + + // Store to the tile in shared memory. 
+ template + inline __device__ void store(Store_type const (&data)[N], uint32_t preds) { + this->store(data, preds); + } + + // Store to the tile in shared memory. TODO: Remove last template arguments. + template + inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t (&preds)[M]) { + uint32_t smem_ptrs[N]; + this->compute_store_pointers(smem_ptrs); + ldgsts(smem_ptrs, gmem_ptrs, preds); + } + + // Store to the tile in shared memory. + template + inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) { + uint32_t tmp[1] = {preds}; + this->store(gmem_ptrs, tmp); + } + + // Store to the tile in shared memory. + template + inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t preds) { + uint32_t tmp[1] = {preds}; + this->store(gmem_ptrs, tmp); + } + + inline __device__ void add_smem_barrier_base(uint64_t*) {} + + // The shared memory pointer. + uint32_t smem_; + // The read offset. Reserve 4 offsets if needed. + int smem_read_offset_; + // The write offset. + int smem_write_offset_; + // The buffer base offset for read. + int smem_read_buffer_; + // The buffer base offset for write. + int smem_write_buffer_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Use TMA +template < + // The description of the tile computed by this CTA. + typename Cta_tile, + // The number of rows in the 2D shared memory buffer. + int M_, + // The number of cols. + int N_, + // The size in bits of each element. + int BITS_PER_ELEMENT_, + // The number of bytes per STS. Not relevant for TMA + int BYTES_PER_STS_, + // The number of buffers. (Used in multistage and double buffer cases.) + int BUFFERS_PER_TILE_, + // Do we enable the fast path for LDS.128 and friends. + int ENABLE_LDS_FAST_PATH_, + // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. + int ROWS_PER_XOR_PATTERN_, + // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. + int COLS_PER_XOR_PATTERN_, + // Use or not predicates + bool USE_PREDICATES_, + // The leading dim elements in shared memory + int LEAD_DIM_ELEMENTS_> +struct Smem_tile_without_skews + : public Smem_tile_without_skews { + // Base struct + using Base = + Smem_tile_without_skews; + static constexpr bool USE_TMA = true; + + // Tile size overrides. 
STS per thread not relevant for TMA + static constexpr int BYTES_PER_BUFFER = M_ * N_ * Base::BITS_PER_ELEMENT / 8; + static constexpr int BYTES_PER_TILE = BYTES_PER_BUFFER * Base::BUFFERS_PER_TILE; + static constexpr int BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER; + // The number of bytes per barrier + static constexpr int BYTES_PER_BARRIER = 8; + + // Ctor + inline __device__ Smem_tile_without_skews(void* smem, int tidx) : Base(smem, tidx) { + this->smem_write_offset_ = __nvvm_get_smem_pointer(smem); + this->smem_barrier_offset_ = 0; + this->elect_one_ = elect_one_sync(); + } + + inline __device__ void add_smem_barrier_base(uint64_t* smem_barrier) { + this->smem_barrier_ = smem_barrier; + this->smem_barrier_offset_ = __nvvm_get_smem_pointer(this->smem_barrier_); + } + + /** + * \brief load tensor blocks from global memory and stores to shared memory using tma instructions + * + * \param p_desc pointer to tma descriptor masked as const void* pointer + * \param smem_offset shared memory offset in bytes relative to smem_write_buffer_ + * \param coord0 tensor access coordinate in dimension 1, used by tma load + * \param coord1 tensor access coordinate in dimension 2, used by tma load + * \param coord2 tensor access coordinate in dimension 3, used by tma load + * \param coord3 tensor access coordinate in dimension 4, used by tma load + * \param coord4 tensor access coordinate in dimension 5, used by tma load + * \param filter_offsets encodes multicast cta id and filter offsets + */ + template + inline __device__ void store(void const* p_desc, unsigned const& smem_offset, int32_t coord0, + int32_t coord1, int32_t coord2, int32_t coord3, int32_t coord4, + uint16_t filter_offsets, uint16_t mcast_cta_mask, + uint64_t mem_desc) { + uint32_t smem = this->smem_write_offset_ + smem_offset; + fmha::utmaldg( + reinterpret_cast(p_desc), smem, unsigned(this->smem_barrier_offset_), + coord0, coord1, coord2, coord3, coord4, filter_offsets, mcast_cta_mask, mem_desc, + this->elect_one_); + } + + // Same function as above but for runtime cga dimension + template + inline __device__ void store(void const* p_desc, unsigned const& smem_offset, int32_t coord0, + int32_t coord1, int32_t coord2, int32_t coord3, int32_t coord4, + uint16_t filter_offsets, uint16_t mcast_cta_mask, uint64_t mem_desc, + bool mcast_enabled) { + uint32_t smem = this->smem_write_offset_ + smem_offset; + fmha::utmaldg(reinterpret_cast(p_desc), smem, + unsigned(this->smem_barrier_offset_), coord0, coord1, coord2, + coord3, coord4, filter_offsets, mcast_cta_mask, mcast_enabled, + mem_desc, this->elect_one_); + } + + // Move the write offset to next buffer. + inline __device__ void move_next_write_buffer() { + if (Base::BUFFERS_PER_TILE > 1) { + this->smem_write_offset_ += (this->smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY) + ? -BYTES_PER_TILE_INC_BOUNDARY + : BYTES_PER_BUFFER; + this->smem_barrier_offset_ += + (this->smem_barrier_offset_ >= Base::BUFFERS_PER_TILE * BYTES_PER_BARRIER) + ? -Base::BUFFERS_PER_TILE * BYTES_PER_BARRIER + : BYTES_PER_BARRIER; + } + } + + inline __device__ void move_next_write_buffer(int buffer_id) { + if (Base::BUFFERS_PER_TILE > 1) { + this->smem_write_offset_ = this->smem_ + buffer_id * BYTES_PER_BUFFER; + } + this->smem_barrier_offset_ = __nvvm_get_smem_pointer(this->smem_barrier_ + buffer_id); + } + + // Move the read offset to next buffer. 
+ // do nothing, as it is controlled by gmma desc + inline __device__ void move_next_read_buffer() {} + + uint64_t* smem_barrier_; + uint32_t smem_barrier_offset_; + // elect one thread to issue utmaldg + uint32_t elect_one_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The layout of the tile. + typename Layout, + // The size of the STS. + int BYTES_PER_STS = 16, + // The number of buffers per tile. + int BUFFERS_PER_TILE = 1, + // Use or not predicates + bool USE_PREDICATES = true> +struct Smem_tile_a {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_volta_a { + // The size in bits. + enum { N_IN_BITS = N * Traits::BITS_PER_ELEMENT_A }; + + // The number of rows. + enum { VALUE = N_IN_BITS <= 256 ? 1 : (N_IN_BITS <= 512 ? 2 : 4) }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Compute_reset_mask { + // The potential mask. + enum { HALF = MMAS_K_WITH_PADDING / 2 }; + + // The remainder. + enum { MOD = MMAS_K % HALF }; + + // The final value. + enum { VALUE = (MMAS_K == MOD ? 0 : HALF) | Compute_reset_mask::VALUE }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Compute_reset_mask<0, MMAS_K_WITH_PADDING> { + enum { VALUE = 0 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Compute_reset_mask { + enum { VALUE = MMAS_K - 1 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_volta_a::VALUE> +struct Smem_tile_volta_row_a + : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The base class. + using Base = Smem_tile_without_skews; + // The fragment. + using Fragment = Fragment_a; + + // When we use padding to reach a power of two, special care has to be taken. + using Cta_tile_with_padding = Cta_tile_with_k_with_padding; + // The number of MMAs. + using Mma_tile_with_padding = typename Traits::template Mma_tile; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_volta_row_a(void* smem, int tidx) : Base(smem, tidx) { + // For documentation on the layout, see doc/xmma_smem_layout.xlsx. + + // The number of warps. + int const WARPS_M = Cta_tile::WARPS_M; + int const WARPS_N = Cta_tile::WARPS_N; + int const WARPS_K = Cta_tile::WARPS_K; + + // The masks to select the warps. + int const WARP_MASK_M = Warp_masks::M; + int const WARP_MASK_K = Warp_masks::K; + + // The divisor for the warps. + int const WARP_DIV_M = 1 * 1 * Cta_tile::THREADS_PER_WARP; + int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + // The row and column read by the thread. 
+ int smem_read_row, smem_read_col; + if (Base::N_WITH_PADDING >= 64) { + smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 1 + + (tidx & 0x10) / 2 + (tidx & 0x07); + smem_read_col = (tidx & 0x03); + } else if (Base::N_WITH_PADDING == 32) { + smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 2 + + (tidx & 0x10) / 4 + (tidx & 0x06) / 2; + smem_read_col = (tidx & 0x02) / 2 + (tidx & 0x01) * 4; + } else { + assert(false); + } + + // For WARPS_K > 1, we do not support Base::N_WITH_PADDING < 64 for the moment. + static_assert(WARPS_K <= 2 && (WARPS_K == 1 || Base::N_WITH_PADDING >= 64), ""); + + // We "swap" the block for the second warp working on the in-CTA split-K. + if (WARPS_K == 2) { + smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K; + } + + // The shared memory offset. + this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; + } + + // Rewind smem_read_offset for last LDS phase in main loop.- + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // Move the offset to the next position. See doc/xmma_smem_layout.xlsx. + this->smem_read_offset_ ^= ((ki % 2 == 0) ? 1 : 3) * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) { +#pragma unroll + for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi) { + // Jump over as many rows as needed. + int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; + + // TODO: Could we fuse smem_read_buffer and smem_read_offset? + uint4 tmp; + lds(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); + a[mi].reg(0) = tmp.x; + a[mi].reg(1) = tmp.y; + a[mi].reg(2) = tmp.z; + a[mi].reg(3) = tmp.w; + } + + // Move the offset to the next position. See doc/xmma_smem_layout.xlsx. + static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); + if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS; + } + } + + // Reset the read offset. + inline __device__ void reset_read_offset() { + // The number of MMAs in the K dimension. + enum { MMAS_K = Mma_tile::MMAS_K }; + + // The number of MMAs in the K dimension when we include padding. + enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; + + // Assemble the mask. + enum { MASK = Compute_reset_mask::VALUE }; + + // Reset the read offset. + this->smem_read_offset_ ^= MASK * BYTES_PER_LDS; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_volta_row_a { + // The traits class. + using Traits = fmha::Volta_hmma_fp16_traits; + // The base class. + using Base = Smem_tile_volta_row_a; + + // Ctor. 
+ inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_turing_a { + // The size in bits. + enum { N_IN_BITS = N * Traits::BITS_PER_ELEMENT_A }; + + // The number of rows. + enum { VALUE = N_IN_BITS <= 128 ? 1 : (N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8)) }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_turing_a::VALUE> +struct Smem_tile_turing_row_a + : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The base class. + using Base = + Smem_tile_without_skews; + // The fragment. + using Fragment = Fragment_a; + + // When we use padding to reach a power of two, special care has to be taken. + using Cta_tile_with_padding = Cta_tile_with_k_with_padding; + // The number of MMAs. + using Mma_tile_with_padding = typename Traits::template Mma_tile; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_turing_row_a(void* smem, int tidx) : Base(smem, tidx) { + // For documentation on the layout, see doc/mma_smem_layout.xlsx. + + // The number of warps. + int const WARPS_M = Cta_tile::WARPS_M; + int const WARPS_N = Cta_tile::WARPS_N; + int const WARPS_K = Cta_tile::WARPS_K; + + // The masks to select the warps. + int const WARP_MASK_M = Warp_masks::M; + int const WARP_MASK_K = Warp_masks::K; + + // The divisor for the warps. + int const WARP_DIV_M = 1 * 1 * Cta_tile::THREADS_PER_WARP; + int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + // The row and column read by the thread. + int smem_read_row, smem_read_col; + + static_assert(Base::ROWS_PER_XOR_PATTERN == 8 || Base::ROWS_PER_XOR_PATTERN == 4 || + Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 1, + ""); + + if (Base::ROWS_PER_XOR_PATTERN == 8) { + smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 1 + (tidx & 0x0f); + smem_read_col = (tidx & 0x07); + } else if (Base::ROWS_PER_XOR_PATTERN == 4) { + smem_read_row = + (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 2 + (tidx & 0x0e) / 2; + smem_read_col = (tidx & 0x06) / 2 + (tidx & 0x01) * 4; + } else if (Base::ROWS_PER_XOR_PATTERN == 2) { + smem_read_row = + (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 4 + (tidx & 0x0c) / 4; + smem_read_col = (tidx & 0x04) / 4 + (tidx & 0x03) * 2; + } else if (Base::ROWS_PER_XOR_PATTERN == 1) { + smem_read_row = + (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 8 + (tidx & 0x1f) / 8; + smem_read_col = (tidx & 0x07); + } + + static_assert(WARPS_K <= 2, ""); + + // We "swap" the block for the second warp working on the in-CTA split-K. + if (WARPS_K == 2) { + smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K; + } + + // The shared memory offset. 
+ this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; + } + + // Rewind smem_read_offset for last LDS phase in main loop.- + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // Move the offset to the next position. See doc/mma_smem_layout.xlsx. + this->smem_read_offset_ ^= ((ki % 2 == 0) ? 1 : 3) * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) { +#pragma unroll + for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi) { + int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; + uint2 tmp; + ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); + a[mi].reg(0) = tmp.x; + a[mi].reg(1) = tmp.y; + } + + // Move the offset to the next position. See doc/mma_smem_layout.xlsx. + static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); + if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS; + } + } + + // Reset the read offset. + inline __device__ void reset_read_offset() { + // The number of MMAs in the K dimension. + enum { MMAS_K = Mma_tile::MMAS_K }; + + // The number of MMAs in the K dimension when we include padding. + enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; + + // Assemble the mask. + enum { MASK = Compute_reset_mask::VALUE }; + + // Reset the read offset. + this->smem_read_offset_ ^= MASK * BYTES_PER_LDS; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_turing_row_a { + // The traits class. + using Traits = Turing_hmma_fp16_traits; + // The base class. + using Base = Smem_tile_turing_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_turing_row_a { + // The traits class. + using Traits = Turing_hmma_fp32_traits; + // The base class. + using Base = Smem_tile_turing_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_turing_row_a { + // The traits class. + using Traits = Turing_imma_int8_int32_traits; + // The base class. 
+ using Base = Smem_tile_turing_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_ampere_a { + // The size in bits. + enum { N_IN_BITS = N * Traits::BITS_PER_ELEMENT_A }; + + // The number of rows. + enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_ampere_row_a : public Rows_per_xor_pattern_ampere_a {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_ampere_row_a::VALUE> +struct Smem_tile_ampere_row_a + : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The base class. + using Base = + Smem_tile_without_skews; + // The fragment. + using Fragment = Fragment_a; + + // When we use padding to reach a power of two, special care has to be taken. + using Cta_tile_with_padding = Cta_tile_with_k_with_padding; + // The number of MMAs. + using Mma_tile_with_padding = typename Traits::template Mma_tile; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_ampere_row_a(void* smem, int tidx) : Base(smem, tidx) { + // For documentation on the layout, see doc/mma_smem_layout.xlsx. + + // The number of warps. + int const WARPS_M = Cta_tile::WARPS_M; + int const WARPS_N = Cta_tile::WARPS_N; + int const WARPS_K = Cta_tile::WARPS_K; + + // The masks to select the warps. + int const WARP_MASK_M = Warp_masks::M; + int const WARP_MASK_K = Warp_masks::K; + + // The divisor for the warps. + int const WARP_DIV_M = 1 * 1 * Cta_tile::THREADS_PER_WARP; + int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + // The row and column read by the thread. + int smem_read_row, smem_read_col; + + static_assert(Base::ROWS_PER_XOR_PATTERN == 8 || Base::ROWS_PER_XOR_PATTERN == 4 || + Base::ROWS_PER_XOR_PATTERN == 2, + ""); + + if (Base::ROWS_PER_XOR_PATTERN == 8) { + smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 1 + (tidx & 0x0f); + smem_read_col = (tidx & 0x07); + smem_read_col ^= (tidx & 0x10) / 16; + } else if (Base::ROWS_PER_XOR_PATTERN == 4) { + smem_read_row = + (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 2 + (tidx & 0x0e) / 2; + smem_read_col = (tidx & 0x06) / 2 + (tidx & 0x01) * 4; + smem_read_col ^= (tidx & 0x10) / 16; + } else if (Base::ROWS_PER_XOR_PATTERN == 2) { + smem_read_row = + (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 4 + (tidx & 0x0c) / 4; + smem_read_col = (tidx & 0x04) / 4 + (tidx & 0x03) * 2; + smem_read_col ^= (tidx & 0x10) / 16; + } + + static_assert(WARPS_K <= 2, ""); + static_assert(WARPS_K != 2 || Base::ROWS_PER_XOR_PATTERN != 2, ""); + + // We "swap" the block for the second warp working on the same outputs in-CTA split-K. 
+ if (WARPS_K == 2) { + smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K * 2; + } + + // The shared memory offset. + this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; + } + + // Rewind smem_read_offset for last LDS phase in main loop. + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // Undo the pointer increment for the next ni. + // Should match the load function below for ki = 0. + if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } + } + + // Load from shared memory. + inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) { + if (ki < Mma_tile::VALID_MMAS_K) { +#pragma unroll + for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi) { + // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows). + int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; + + // Load using LDSM.M88.4. + uint4 tmp; + ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); + + // Store the value into the fragment. + a[mi].reg(0) = tmp.x; + a[mi].reg(1) = tmp.y; + a[mi].reg(2) = tmp.z; + a[mi].reg(3) = tmp.w; + } + } + + // Move the offset to the next position. See doc/mma_smem_layout.xlsx. + static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); + if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; + } + } + + // Reset the read offset. + inline __device__ void reset_read_offset() { + // The number of MMAs in the K dimension. + enum { MMAS_K = Mma_tile::MMAS_K }; + + // The number of MMAs in the K dimension when we include padding. + enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; + + // Assemble the mask. + enum { MASK = Compute_reset_mask::VALUE }; + + // Reset the read offset. + this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_ampere_row_a { + // The traits class. + using Traits = Ampere_hmma_fp16_traits; + // The base class. + using Base = Smem_tile_ampere_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_ampere_row_a { + // The traits class. + using Traits = Ampere_hmma_fp32_traits; + // The base class. + using Base = Smem_tile_ampere_row_a; + + // Ctor. 
+ inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_ampere_row_a { + // The traits class. + using Traits = Ampere_hmma_bf16_traits; + // The base class. + using Base = Smem_tile_ampere_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_ampere_row_a { + // The traits class. + using Traits = Ampere_imma_int8_int32_traits; + // The base class. + using Base = Smem_tile_ampere_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_ampere_row_a { + // The traits class. + using Traits = Ada_qmma_e4m3_fp32_traits; + // The base class. + using Base = Smem_tile_ampere_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_ampere_row_a { + // The traits class. + using Traits = Ada_qmma_e4m3_fp16_traits; + // The base class. + using Base = Smem_tile_ampere_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The layout of the tile. + typename Layout, + // The size of the STS. + int BYTES_PER_STS = 16, + // The number of buffers per tile. + int BUFFERS_PER_TILE = 1, + // Use or not predicates + bool USE_PREDICATES = true> +struct Smem_tile_b {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_volta_b { + // The size in bits. + enum { N_IN_BITS = N * Traits::BITS_PER_ELEMENT_B }; + + // The number of rows. + enum { VALUE = N_IN_BITS <= 256 ? 1 : (N_IN_BITS <= 512 ? 2 : 4) }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. 
+ int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_volta_b::VALUE> +struct Smem_tile_volta_col_b + : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The base class. + using Base = Smem_tile_without_skews; + + // When we use padding to reach a power of two, special care has to be taken. + using Cta_tile_with_padding = Cta_tile_with_k_with_padding; + // The number of MMAs. + using Mma_tile_with_padding = typename Traits::template Mma_tile; + // The fragment. + using Fragment = Fragment_b; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_volta_col_b(void* smem, int tidx) : Base(smem, tidx) { + // For documentation on the layout, see doc/xmma_smem_layout.xlsx. + + // The number of warps. + int const WARPS_M = Cta_tile::WARPS_M; + int const WARPS_N = Cta_tile::WARPS_N; + int const WARPS_K = Cta_tile::WARPS_K; + + // The masks to select the warps. + int const WARP_MASK_N = Warp_masks::N; + int const WARP_MASK_K = Warp_masks::K; + + // The divisor for the warps. + int const WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + // The row and column read by the thread. + int smem_read_row, smem_read_col; + + if (Base::N_WITH_PADDING >= 64) { + smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 1 + + (tidx & 0x18) / 2 + (tidx & 0x03); + smem_read_col = (tidx & 0x03); + } else if (Base::N_WITH_PADDING == 32) { + smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 2 + + (tidx & 0x18) / 4 + (tidx & 0x02) / 2; + smem_read_col = (tidx & 0x02) / 2 + (tidx & 0x01) * 4; + } else { + assert(false); + } + + // For WARPS_K > 1, we do not support Base::N_WITH_PADDING < 64 for the moment. + static_assert(WARPS_K <= 2 && (WARPS_K == 1 || Base::N_WITH_PADDING >= 64), ""); + + // We "swap" the block for the second warp working on the in-CTA split-K. + if (WARPS_K == 2) { + smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K; + } + + // The shared memory offset. + this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; + } + + // Rewind smem_read_offset for last LDS phase in main loop.- + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // Move the offset to the next position. See doc/xmma_smem_layout.xlsx. + this->smem_read_offset_ ^= ((ki % 2 == 0) ? 1 : 3) * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // Jump over as many rows as needed. + int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; + + // TODO: Can we fuse read_offset and read_buffer? + uint4 tmp; + lds(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + } + + // Move the offset to the next position. See doc/xmma_smem_layout.xlsx. 
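+ // (Worked example, assuming Mma_tile_with_padding::MMAS_K == 8: starting from slot 0, the
+ // successive ki steps below XOR the offset by 1, 3, 1, 7, 1, 3, 1, 7 LDS-widths, visiting the
+ // swizzled 16B slots 1, 2, 3, 4, 5, 6, 7 and returning to slot 0 after the last step.)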
+ static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); + if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS; + } + } + + // Reset the read offset. + inline __device__ void reset_read_offset() { + // The number of MMAs in the K dimension. + enum { MMAS_K = Mma_tile::MMAS_K }; + + // The number of MMAs in the K dimension when we include padding. + enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; + + // Assemble the mask. + enum { MASK = Compute_reset_mask::VALUE }; + + // Reset the read offset. + this->smem_read_offset_ ^= MASK * BYTES_PER_LDS; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_volta_col_b { + // The traits class. + using Traits = fmha::Volta_hmma_fp16_traits; + // The base class. + using Base = Smem_tile_volta_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_turing_b { + // The size in bits. + enum { N_IN_BITS = N * Traits::BITS_PER_ELEMENT_B }; + + // The number of rows. + enum { VALUE = N_IN_BITS <= 128 ? 1 : (N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8)) }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_turing_b::VALUE> +struct Smem_tile_turing_col_b + : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The base class. + using Base = + Smem_tile_without_skews; + // The fragment. + using Fragment = Fragment_b; + + // When we use padding to reach a power of two, special care has to be taken. + using Cta_tile_with_padding = Cta_tile_with_k_with_padding; + // The number of MMAs. + using Mma_tile_with_padding = typename Traits::template Mma_tile; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_turing_col_b(void* smem, int tidx) : Base(smem, tidx) { + // For documentation on the layout, see doc/mma_smem_layout.xlsx. + + // The number of warps. + int const WARPS_M = Cta_tile::WARPS_M; + int const WARPS_N = Cta_tile::WARPS_N; + int const WARPS_K = Cta_tile::WARPS_K; + + // The masks to select the warps. + int const WARP_MASK_N = Warp_masks::N; + int const WARP_MASK_K = Warp_masks::K; + + // The divisor for the warps. 
+ int const WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + // The row and column read by the thread. + int smem_read_row, smem_read_col; + + static_assert(Base::ROWS_PER_XOR_PATTERN == 8 || Base::ROWS_PER_XOR_PATTERN == 4 || + Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 1, + ""); + + if (Base::ROWS_PER_XOR_PATTERN == 8) { + // For group fprop. B is divided into 2 halves along N dimension. + // The fist warp takes the first half and the second warp takes the second half. + smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 1 + (tidx & 0x0f); + smem_read_col = (tidx & 0x07); + } else if (Base::ROWS_PER_XOR_PATTERN == 4) { + smem_read_row = + (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 2 + (tidx & 0x0e) / 2; + smem_read_col = (tidx & 0x06) / 2 + (tidx & 0x01) * 4; + } else if (Base::ROWS_PER_XOR_PATTERN == 2) { + smem_read_row = + (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 4 + (tidx & 0x0c) / 4; + smem_read_col = (tidx & 0x04) / 4 + (tidx & 0x03) * 2; + } else if (Base::ROWS_PER_XOR_PATTERN == 1) { + smem_read_row = + (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 8 + (tidx & 0x1f) / 8; + smem_read_col = (tidx & 0x07); + } + + static_assert(WARPS_K <= 2, ""); + + // We "swap" the block for the second warp working on the in-CTA split-K. + if (WARPS_K == 2) { + smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K; + } + + // The shared memory offset. + this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; + } + + // Rewind smem_read_offset for last LDS phase in main loop.- + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // Move the offset to the next position. See doc/mma_smem_layout.xlsx. + this->smem_read_offset_ ^= ((ki % 2 == 0) ? 1 : 3) * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; + uint2 tmp; + ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + } + // Move the offset to the next position. See doc/mma_smem_layout.xlsx. + static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); + if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS; + } + } + + // Reset the read offset. + inline __device__ void reset_read_offset() { + // The number of MMAs in the K dimension. + enum { MMAS_K = Mma_tile::MMAS_K }; + + // The number of MMAs in the K dimension when we include padding. + enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; + + // Assemble the mask. + enum { MASK = Compute_reset_mask::VALUE }; + + // Reset the read offset. 
+ this->smem_read_offset_ ^= MASK * BYTES_PER_LDS; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_turing_col_b { + // The traits class. + using Traits = Turing_hmma_fp16_traits; + // The base class. + using Base = Smem_tile_turing_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_turing_col_b { + // The traits class. + using Traits = Turing_hmma_fp32_traits; + // The base class. + using Base = Smem_tile_turing_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_turing_col_b { + // The traits class. + using Traits = Turing_imma_int8_int32_traits; + // The base class. + using Base = Smem_tile_turing_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_ampere_b { + // The size in bits. + enum { N_IN_BITS = N * Traits::BITS_PER_ELEMENT_B }; + + // The number of rows. + enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_ampere_col_b : public Rows_per_xor_pattern_ampere_b {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_ampere_col_b::VALUE> +struct Smem_tile_ampere_col_b + : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The base class. + using Base = + Smem_tile_without_skews; + // The fragment. + using Fragment = Fragment_b; + + // When we use padding to reach a power of two, special care has to be taken. + using Cta_tile_with_padding = Cta_tile_with_k_with_padding; + // The number of MMAs. + using Mma_tile_with_padding = typename Traits::template Mma_tile; + + // The size of a single LDS in bytes. 
+ enum { BYTES_PER_LDS = 16 }; + + // The number of STS per thread + enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA }; + + // The number of STS per thread must be at least 1. + enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE }; + + // Ctor. + inline __device__ Smem_tile_ampere_col_b(void* smem, int tidx) : Base(smem, tidx) { + // For documentation on the layout, see doc/mma_smem_layout.xlsx. + + // The number of warps. + int const WARPS_M = Cta_tile::WARPS_M; + int const WARPS_N = Cta_tile::WARPS_N; + int const WARPS_K = Cta_tile::WARPS_K; + + // The masks to select the warps. + int const WARP_MASK_N = Warp_masks::N; + int const WARP_MASK_K = Warp_masks::K; + + // The divisor for the warps. + int const WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + // The row and column read by the thread. + int smem_read_row, smem_read_col; + + static_assert(Base::ROWS_PER_XOR_PATTERN == 8 || Base::ROWS_PER_XOR_PATTERN == 4 || + Base::ROWS_PER_XOR_PATTERN == 2, + ""); + + if (Base::ROWS_PER_XOR_PATTERN == 8) { + // For group fprop. B is divided into 2 halves along N dimension. + // The fist warp takes the first half and the second warp takes the second half. + smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 1 + (tidx & 0x07) + + (tidx & 0x10) / 2; + smem_read_col = (tidx & 0x07); + smem_read_col ^= (tidx & 0x08) / 8; + } else if (Base::ROWS_PER_XOR_PATTERN == 4) { + smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 2 + + (tidx & 0x06) / 2 + (tidx & 0x10) / 4; + smem_read_col = (tidx & 0x06) / 2 + (tidx & 0x01) * 4; + smem_read_col ^= (tidx & 0x08) / 8; + } else if (Base::ROWS_PER_XOR_PATTERN == 2) { + smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 4 + + (tidx & 0x04) / 4 + (tidx & 0x10) / 8; + smem_read_col = (tidx & 0x04) / 4 + (tidx & 0x03) * 2; + smem_read_col ^= (tidx & 0x08) / 8; + } + + static_assert(WARPS_K <= 2, ""); + static_assert(WARPS_K != 2 || Base::ROWS_PER_XOR_PATTERN != 2, ""); + + // We "swap" the block for the second warp working on the in-CTA split-K. + if (WARPS_K == 2) { + smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K * 2; + } + + // The shared memory offset. + this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; + } + + // Rewind smem_read_offset for last LDS phase in main loop. + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // Undo the pointer increment for the next ni. + // Should match the load function below for ki = 0. + if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { + if (ki < Mma_tile::VALID_MMAS_K) { +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows). + int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; + + // Load using LDSM.M88.4. + uint4 tmp; + ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); + + // Store the value into the fragment. + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + } + } + + // Move the offset to the next position. See doc/mma_smem_layout.xlsx. 
+ static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); + if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; + } + } + + // Reset the read offset. + inline __device__ void reset_read_offset() { + // The number of MMAs in the K dimension. + enum { MMAS_K = Mma_tile::MMAS_K }; + + // The number of MMAs in the K dimension when we include padding. + enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; + + // Assemble the mask. + enum { MASK = Compute_reset_mask::VALUE }; + + // Reset the read offset. + this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_ampere_col_b { + // The traits class. + using Traits = Ampere_hmma_fp16_traits; + // The base class. + using Base = Smem_tile_ampere_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_ampere_col_b { + // The traits class. + using Traits = Ampere_hmma_fp32_traits; + // The base class. + using Base = Smem_tile_ampere_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_ampere_col_b { + // The traits class. + using Traits = Ampere_hmma_bf16_traits; + // The base class. + using Base = Smem_tile_ampere_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_ampere_col_b { + // The traits class. + using Traits = Ampere_imma_int8_int32_traits; + // The base class. + using Base = Smem_tile_ampere_col_b; + + // Ctor. 
+ inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_ampere_col_b { + // The traits class. + using Traits = Ada_qmma_e4m3_fp32_traits; + // The base class. + using Base = Smem_tile_ampere_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_ampere_col_b { + // The traits class. + using Traits = Ada_qmma_e4m3_fp16_traits; + // The base class. + using Base = Smem_tile_ampere_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_ampere_row_b : public Rows_per_xor_pattern_ampere_b {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_ampere_row_b::VALUE, + // How many cols to use for the XOR pattern to avoid bank conflicts? + int COLS_PER_XOR_PATTERN_ = 1> +struct Smem_tile_ampere_row_b + : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The base class. + using Base = Smem_tile_without_skews; + // The fragment. + using Fragment = Fragment_b; + + // Can we use LDSM? No if the data type is 32-bit large. + enum { USE_LDSMT = Traits::BITS_PER_ELEMENT_B == 16 }; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 }; + + // The number of elements per LDS. + enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / Traits::BITS_PER_ELEMENT_B }; + + // The number of STS per thread + enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA }; + + // The number of STS per thread must be at least 1. + enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE }; + + // Ctor. + inline __device__ Smem_tile_ampere_row_b(void* smem, int tidx) : Base(smem, tidx) { + // For documentation on the layout, see doc/xmma_smem_layout.xlsx. + + // The number of warps. + int const WARPS_M = Cta_tile::WARPS_M; + int const WARPS_N = Cta_tile::WARPS_N; + int const WARPS_K = Cta_tile::WARPS_K; + + // The masks to select the warps. + int const WARP_MASK_N = Warp_masks::N; + int const WARP_MASK_K = Warp_masks::K; + + // The divisor for the warps. + int const WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + // The row/col read by the thread. 
+ int smem_read_row, smem_read_col; + + static_assert((USE_LDSMT && Base::ROWS_PER_XOR_PATTERN == 8) || + Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 2, + ""); + + if (USE_LDSMT && Base::ROWS_PER_XOR_PATTERN == 8) { + // For group dgrad. B is divided into 2 halves along K dimension. + // The fist warp takes the first half and the second warp takes the second half. + smem_read_row = + (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 + (tidx & 0x07) + (tidx & 0x08); + smem_read_col = (tidx & 0x07); + } else if (USE_LDSMT && Base::ROWS_PER_XOR_PATTERN == 4) { + smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 8 + (tidx & 0x06) / 2 + + (tidx & 0x08) / 2; + smem_read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + } else if (USE_LDSMT && Base::ROWS_PER_XOR_PATTERN == 2) { + smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 4 + (tidx & 0x04) / 4 + + (tidx & 0x08) / 4; + smem_read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + } else if (Base::ROWS_PER_XOR_PATTERN == 4 && Base::COLS_PER_XOR_PATTERN == 2) { + smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 8 + (tidx & 0x03); + smem_read_col = (tidx & 0x1c) / 4 + (tidx & 0x03) * 8; + } + + // Each half-warp applies a different XOR pattern -- see the Excel document. + if (USE_LDSMT) { + smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16; + } else { + smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 16; + } + + // The shared memory offset. + this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; + + // Fill zeroes for group conv + } + + // Rewind smem_read_offset for last LDS phase in main loop. + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // The size of each element in bits. + int const BITS_PER_ELT = Traits::BITS_PER_ELEMENT_B; + // The size in bytes of the data needed to compute an MMA per CTA. + int const BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // Undo the pointer increment for the next ni. + // Should match the load function below for ki = 0. + if (BYTES_PER_MMA_PER_CTA >= 128) { + // Nothing to do! + } else if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } else if (BYTES_PER_MMA_PER_CTA == 64) { + // Nothing to do! + } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); + } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } + } + + // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels) + if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && Mma_tile::MMAS_N % 2 == 1) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::VALID_MMAS_N], int ki) { + // The size of each element in bits. + int const BITS_PER_ELT = Traits::BITS_PER_ELEMENT_B; + // The size in bytes of the data needed to compute an MMA per CTA. + int const BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // Prepare the offset. 
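// Illustrative sketch (not part of the patch): the branches below key off how many bytes
// one MMA's worth of B columns occupies per CTA, and how the per-ni base offset is formed
// from it. A host-side restatement; the example values for a 16-bit B operand are assumptions:
constexpr int bytes_per_mma_per_cta(int n_per_mma_per_cta, int bits_per_elt) {
  return n_per_mma_per_cta * bits_per_elt / 8;
}

constexpr int ni_byte_offset(int bytes_per_mma_per_cta_, int ni) {
  return bytes_per_mma_per_cta_ == 32   ? 0  // handled purely by the XOR swizzle
         : bytes_per_mma_per_cta_ == 64 ? (ni / 2) * bytes_per_mma_per_cta_ * 2
                                        : ni * bytes_per_mma_per_cta_;  // plain linear stride
}

static_assert(bytes_per_mma_per_cta(16, 16) == 32, "N_PER_MMA_PER_CTA = 16 -> 32B branch");
static_assert(bytes_per_mma_per_cta(64, 16) == 128, "N_PER_MMA_PER_CTA = 64 -> 128B branch");
static_assert(ni_byte_offset(128, 3) == 384, "wide rows: linear ni stride");
static_assert(ni_byte_offset(64, 3) == 128, "64B rows: MMA pairs share a 128B block");
static_assert(ni_byte_offset(32, 3) == 0, "32B rows: no additive term, only XOR");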
+ int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW; + if (BYTES_PER_MMA_PER_CTA == 32) { + offset += this->smem_read_offset_; + } else if (BYTES_PER_MMA_PER_CTA == 64) { + offset += this->smem_read_offset_ + (ni / 2) * BYTES_PER_MMA_PER_CTA * 2; + } else { + offset += this->smem_read_offset_ + (ni)*BYTES_PER_MMA_PER_CTA; + } + + // Load the data using LDSM.MT88.2. + uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset; + + if (ni < Mma_tile::VALID_MMAS_N) { + uint4 tmp; + if (USE_LDSMT) { + ldsmt(tmp, ptr); + } else { + lds(tmp.x, (ptr) + 0 * Base::BYTES_PER_ROW); + lds(tmp.y, (ptr) + 4 * Base::BYTES_PER_ROW); + lds(tmp.z, (ptr ^ 32) + 0 * Base::BYTES_PER_ROW); + lds(tmp.w, (ptr ^ 32) + 4 * Base::BYTES_PER_ROW); + } + + // Store those values in the fragment. + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + } + + // static_assert(BYTES_PER_MMA_PER_CTA >= 128 || + // BYTES_PER_MMA_PER_CTA == 64 || + // (BYTES_PER_MMA_PER_CTA == 32 && + // (Mma_tile::MMAS_M == 4 || + // Mma_tile::MMAS_M == 2 || + // Mma_tile::MMAS_M == 1)), ""); + + // Move the pointer for the next ni. I expect the compiler to not recompute those. + if (BYTES_PER_MMA_PER_CTA >= 128) { + // Nothing to do! + } else if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } else if (BYTES_PER_MMA_PER_CTA == 64) { + // Nothing to do! + } else if (BYTES_PER_MMA_PER_CTA == 32) { + if ((ni & 1) == 0) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } else if (Mma_tile::MMAS_N >= 16 && (ni & 7) == 7) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 30; + } else if (Mma_tile::MMAS_N >= 8 && (ni & 3) == 3) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 14; + } else if (Mma_tile::MMAS_N >= 4 && (ni & 1) == 1) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 6; + } + } + } + + // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels) + if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && Mma_tile::MMAS_N % 2 == 1) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_ampere_row_b { + // The traits class. + using Traits = Ampere_hmma_fp32_traits; + // The base class. + using Base = Smem_tile_ampere_row_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_ampere_row_b { + // The traits class. + using Traits = Ampere_hmma_bf16_traits; + // The base class. + using Base = Smem_tile_ampere_row_b; + + // Ctor. 
+ inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/smem_tile_o.h b/csrc/fmha_v2/fmha/smem_tile_o.h new file mode 100644 index 0000000000..af7311a111 --- /dev/null +++ b/csrc/fmha_v2/fmha/smem_tile_o.h @@ -0,0 +1,1646 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// H M M A +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o { + // The instruction traits. + using Traits = Volta_hmma_fp16_16x16x16_traits; + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + // The accumulators. + using Data_type = typename Accumulator::Data_type; + + // The size of each element. + enum { BYTES_PER_ELEMENT = sizeof(Data_type) }; + + // The size of each STS. + enum { BYTES_PER_STS = 16 }; + + // The size of each row in shared memory. + enum { BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * 2 * BYTES_PER_ELEMENT }; + + // The size of each LDS. + enum { BYTES_PER_LDS = 16 }; + + // The number of threads (to produce 16B per LDS). + enum { THREADS_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT / BYTES_PER_LDS }; + + // The number of rows loaded per LDS. + enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of rows. + enum { ROWS = Cta_tile::M }; + + // We want at least one output per thread (if possible). + enum { ROWS_PER_LOOP_ = ROWS <= 64 ? ROWS : (int)Min::VALUE }; + + // We also want to have "complete" MMAs. + enum { ROWS_PER_LOOP = Max::VALUE }; + + // The number of outer loops. + enum { LOOPS = fmha::Div_up::VALUE }; + + // Make sure it matches our expectations. + static_assert(ROWS_PER_LOOP >= (int)Mma_tile::M_PER_MMA_PER_CTA, ""); + + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0 }; + + // The total number of LDS per loop. + enum { LDS_PER_LOOP = fmha::Div_up::VALUE }; + + // The amount of shared memory. + enum { BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW }; + + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // Determine the config. 
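// Illustrative sketch (not part of the patch): the flags below simply name the warp layout
// as WARPS_M x WARPS_N x WARPS_K. A host-side illustration of the CTA sizes those layouts
// imply (a 32-thread warp is assumed):
constexpr int threads_per_cta(int warps_m, int warps_n, int warps_k) {
  return warps_m * warps_n * warps_k * 32;
}

// Flash-attention style: all four warps split the M (query) dimension, no split-K.
static_assert(threads_per_cta(4, 1, 1) == 128, "WARPS_4x1x1 -> 128 threads");
// Split-K style: two warps along M, two along K, reduced through this smem tile.
static_assert(threads_per_cta(2, 1, 2) == 128, "WARPS_2x1x2 -> 128 threads");
// Deep split-K for longer sequences: eight warps along K.
static_assert(threads_per_cta(1, 1, 8) == 256, "WARPS_1x1x8 -> 256 threads");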
+ enum { WARPS_2x1x2 = WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_1x1x8 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 8 }; + + enum { WARPS_1x1x4 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 4 }; + + // Flash Attention uses WARPS_4x1x1 + enum { WARPS_4x1x1 = WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1 }; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) { + // Get a 32-bit value for the shared memory address. + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + + // The row/col written by the thread. + int write_row, write_col; + + // SEQLEN == 128. Segments of 128B are written by 2 warps. + if (WARPS_2x1x2 && Cta_tile::N == 32) { + write_row = (tidx & 0x30) / 2 + (tidx & 0x07); + write_col = (tidx & 0x0f); + write_col ^= (tidx & 0x40) / 16; + + // SEQLEN == 128 and N == 64. + } else if (WARPS_2x1x2 && Cta_tile::N == 64) { + write_row = (tidx & 0x30) / 2 + (tidx & 0x07); + write_col = (tidx & 0x40) / 8 + (tidx & 0x08) * 2 + (tidx & 0x07); + + // SEQLEN == 256, 384 and N == 32. Segments of 128B are written by 2 warps. + } else if (WARPS_1x1x4 && Cta_tile::N == 32) { + write_row = (tidx & 0x10) / 2 + (tidx & 0x07); + write_col = (tidx & 0x40) / 8 + (tidx & 0x08) * 2 + (tidx & 0x07); + write_col ^= (tidx & 0x20) / 8; + + // SEQLEN == 256, 384 and N == 64. + } else if (WARPS_1x1x4 && Cta_tile::N == 64) { + write_row = (tidx & 0x10) / 2 + (tidx & 0x07); + write_col = (tidx & 0x60) / 4 + (tidx & 0x08) * 4 + (tidx & 0x07); + + // SEQLEN == 256, 384, 512 and N == 128. + } else if (WARPS_1x1x4 && Cta_tile::N == 128) { + write_row = (tidx & 0x10) / 2 + (tidx & 0x07); + write_col = (tidx & 0x60) / 2 + (tidx & 0x08) * 8 + (tidx & 0x07); + + // SEQLEN == 256, 384, 512 and N == 256. + } else if (WARPS_1x1x4 && Cta_tile::N == 256) { + write_row = (tidx & 0x10) / 2 + (tidx & 0x07); + write_col = (tidx & 0x60) / 1 + (tidx & 0x08) * 16 + (tidx & 0x07); + + // SEQLEN == 256, 384, 512 and N == 32. Segments of 128B are written by 2 warps. + } else if (WARPS_1x1x8 && Cta_tile::N == 32) { + write_row = (tidx & 0x10) / 2 + (tidx & 0x07); + write_col = (tidx & 0xc0) / 8 + (tidx & 0x08) * 4 + (tidx & 0x07); + write_col ^= (tidx & 0x20) / 8; + + // SEQLEN == 256, 384, 512 and N == 64. + } else if (WARPS_1x1x8 && Cta_tile::N == 64) { + write_row = (tidx & 0x10) / 2 + (tidx & 0x07); + write_col = (tidx & 0xe0) / 4 + (tidx & 0x08) * 8 + (tidx & 0x07); + + // ANY SEQLEN and N == 32 + } else if (WARPS_4x1x1 && Cta_tile::N == 32) { + write_row = (tidx & 0xf0) / 2 + (tidx & 0x07); + write_col = (tidx & 0x07); + write_col ^= (tidx & 0x08) / 2; + + // ANY SEQLEN and N == 64 + } else if (WARPS_4x1x1 && Cta_tile::N == 64) { + write_row = (tidx & 0x70) / 2 + (tidx & 0x07); + write_col = (tidx & 0x0f); + + // ANY SEQLEN and N == 128 + } else if (WARPS_4x1x1 && Cta_tile::N == 128) { + write_row = (tidx & 0x70) / 2 + (tidx & 0x07); + write_col = (tidx & 0x08) + (tidx & 0x0f); + + // ANY SEQLEN and N == 256 + } else if (WARPS_4x1x1 && Cta_tile::N == 256) { + write_row = (tidx & 0x70) / 2 + (tidx & 0x07); + write_col = (tidx & 0x08) * 3 + (tidx & 0x0f); + + // Not supported. + } else { + assert(false); + } + + // Assemble the write pointer. + smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; + + // The element read by each thread. + int read_row = tidx / THREADS_PER_ROW; + int read_col = tidx % THREADS_PER_ROW; + + // Take the XOR pattern into account for the column. + read_col ^= read_row & 0x7; + + // Assemble the read pointer. 
+ smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + + // Is that thread active on the last LDS? + if (HAS_INCOMPLETE_LDS) { + is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M; + } + } + + // Load the output fragments. + inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { + uint32_t local_smem_read_ = smem_read_; +#pragma unroll + for (int ii = 0; ii < LDS_PER_LOOP; ++ii) { + // Apply the XOR pattern if needed. (XOR 8 default) + if (ROWS_PER_LDS < 8) { + local_smem_read_ = (smem_read_ ^ ((ii * ROWS_PER_LDS) % 8 * BYTES_PER_LDS)); + } + + // Load the elements before the reduction (split-K). + uint4 tmp[Cta_tile::WARPS_K * 2]; +#pragma unroll + for (int jj = 0; jj < Cta_tile::WARPS_K * 2; ++jj) { + // The immediate. + int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW; + if (Cta_tile::N == 256) { + imm += jj * 512; + } else if (Cta_tile::N == 128) { + imm += jj * 256; + } else if (Cta_tile::N == 64) { + imm += jj * 128; + } else if (Cta_tile::N == 32) { + imm += jj / 2 * 128; + } else { + assert(false); + } + + // The XOR mask. + int smem_read_offset = local_smem_read_; + if (Cta_tile::N == 32 && (jj % 2) == 1) { + smem_read_offset ^= 64; + } + + // Load... + if (!HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || is_active_for_last_lds_)) { + fmha::lds(tmp[jj], smem_read_offset + imm); + } + } + + // Perform the reduction. + out[ii] = tmp[0]; +#pragma unroll + for (int jj = 1; jj < Cta_tile::WARPS_K * 2; ++jj) { + out[ii] = fmha::hadd8(out[ii], tmp[jj]); + } + } + } + + // Store the accumulators. + template + inline __device__ void store(Accumulator const (&acc)[M][N], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::VALID_MMAS_N; ++ni) { + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + +// Store 1st column of the different MMAs. +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Assemble the vectors for the stores. See how we swizzle the registers. + uint4 tmp_0; + tmp_0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0); + tmp_0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1); + tmp_0.z = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4); + tmp_0.w = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5); + + uint4 tmp_1; + tmp_1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2); + tmp_1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3); + tmp_1.z = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6); + tmp_1.w = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7); + + // Precompute the immediates to jump to the correct row. + int row = mj * M_PER_MMA * BYTES_PER_ROW; + + // The columns. + int smem_write_0 = smem_write_ ^ ((2 * ni + 0) * BYTES_PER_STS); + int smem_write_1 = smem_write_ ^ ((2 * ni + 1) * BYTES_PER_STS); + + // Store. + fmha::sts(smem_write_0 + row, tmp_0); + fmha::sts(smem_write_1 + row, tmp_1); + } + } + } + + // The write pointer. + uint32_t smem_write_; + // The write pointer. + uint32_t smem_read_; + // Is the thread active for the last LDS of the series? + int is_active_for_last_lds_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// This class converts the FP16/FP32 inputs to FP16x2. + +struct Convert_from_fp16 { + // Convert one pair of fp16 numbers. + template + static inline __device__ uint32_t convert(Accumulators const& acc, int ii) { + // Extract the 2x FP16 numbers (packed in a register). 
+ uint32_t h2 = acc.reg(ii); + + return h2; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Convert_from_fp32 { + // Convert one pair of fp16 numbers. + template + inline __device__ uint32_t convert(Accumulators const& acc, int ii) { + // Extract the 2x floats. + float f0 = acc.elt(ii * 2 + 0); + float f1 = acc.elt(ii * 2 + 1); + + // Convert to FP16x2. + return fmha::float2_to_half2(f0, f1); + } + + // The bf16 accumulators (convert from fp32 to 2xbf16). + using Ampere_bf16_Accumulator = fmha::Fragment_accumulator; + + static inline __device__ uint32_t convert(Ampere_bf16_Accumulator const& acc, int ii) { + // Extract the 2x floats. + float f0 = acc.elt(ii * 2 + 0); + float f1 = acc.elt(ii * 2 + 1); + + // Convert to FP16x2. + return fmha::float2_to_bf16_x2(f0, f1); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hmma_smem_tile_o { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + // The data type. + using Data_type = typename Accumulator::Data_type; + // The epilogue data type + using Epilogue_type = typename Traits::Epilogue_type; + + // The size of each element. + enum { BYTES_PER_ELEMENT = sizeof(Epilogue_type) }; + + // The amount of bytes per row (without packing or split-k). + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // The size of each STS. + enum { BYTES_PER_STS = BYTES_PER_STS_ }; + + // The size of each LDS. + enum { BYTES_PER_LDS = 16 }; + + // The number of threads (to produce 16B per LDS). + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS }; + + // The number of rows loaded per LDS. + enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of rows in shared memory. + enum { ROWS = Cta_tile::M }; + + // We want at least one output per thread (if possible). + enum { ROWS_PER_LOOP_ = ROWS <= 64 ? ROWS : (int)Min::VALUE }; + + // We also want to have "complete" MMAs. + enum { ROWS_PER_LOOP = Max::VALUE }; + + // The number of outer loops. + enum { LOOPS = fmha::Div_up::VALUE }; + + // Make sure it matches our expectations. + static_assert(ROWS_PER_LOOP >= (int)Mma_tile::M_PER_MMA_PER_CTA, ""); + + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0 }; + + // The total number of LDS per loop. + enum { LDS_PER_LOOP = fmha::Div_up::VALUE }; + + // The amount of shared memory. + enum { BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW * Cta_tile::WARPS_K }; + + // The amount of row packing to make sure we have at least 128B per smem row (without split-k). + enum { ROW_PACKING = Max<1, 128 / BYTES_PER_ROW>::VALUE }; + + // Make sure our row packing is correct + static_assert(ROWS_PER_LOOP % ROW_PACKING == 0, ""); + + // The amount of shared memory per row after packing. + enum { BYTES_PER_ROW_WITH_PACKING = BYTES_PER_ROW * ROW_PACKING }; + + // Make sure we have at least 128B per row after packing. + static_assert(BYTES_PER_ROW_WITH_PACKING >= 128, ""); + + // The number of threads per row after packing. + enum { THREADS_PER_ROW_WITH_PACKING = THREADS_PER_ROW * ROW_PACKING }; + + // Make sure we have at least 8 threads per row after packing. + static_assert(THREADS_PER_ROW_WITH_PACKING >= 8, ""); + + // Warps. 
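// Illustrative sketch (not part of the patch): ROW_PACKING above squeezes several logical
// output rows into one 128B smem row when the head size is small. A host-side restatement,
// assuming a 2-byte epilogue type (fp16/bf16):
constexpr int row_packing(int n, int bytes_per_element) {
  return 128 / (n * bytes_per_element) > 1 ? 128 / (n * bytes_per_element) : 1;
}

static_assert(row_packing(16, 2) == 4, "N = 16: four 32B rows per 128B smem row");
static_assert(row_packing(32, 2) == 2, "N = 32: two 64B rows per 128B smem row");
static_assert(row_packing(64, 2) == 1, "N = 64: a row already fills 128B, no packing");
static_assert(row_packing(256, 2) == 1, "N = 256: no packing either");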
+ enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // Determine the config. + enum { WARPS_2x1x2 = WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_1x1x8 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 8 }; + + enum { WARPS_1x1x4 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 4 }; + + // Flash Attention uses WARPS_4x1x1 + enum { WARPS_4x1x1 = WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1 }; + + enum { WARPS_4x1x2 = WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 2 }; + + // Ctor. + inline __device__ Hmma_smem_tile_o(void* smem, int tidx) { + // Get a 32-bit value for the shared memory address. + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + + // The row/col written by the thread. + int write_row, write_col; + + // SEQLEN == 128 and HIDDEN_SIZE_PER_HEAD == 16. + if (WARPS_2x1x2 && Cta_tile::N == 16) { + write_row = (tidx & 0x20) / 8 + (tidx & 0x10) / 16; + write_col = (tidx & 0x40) / 2 + (tidx & 0x0c) * 2 + (tidx & 0x03); + write_col ^= (tidx & 0x10) / 4; + + // SEQLEN == 128 and HIDDEN_SIZE_PER_HEAD == 32. + } else if (WARPS_2x1x2 && Cta_tile::N == 32) { + write_row = (tidx & 0x20) / 4 + (tidx & 0x18) / 8; + write_col = (tidx & 0x40) / 2 + (tidx & 0x04) * 4 + (tidx & 0x03); + write_col ^= (tidx & 0x18) / 2; + + // SEQLEN == 128 and HIDDEN_SIZE_PER_HEAD == 64. + } else if (WARPS_2x1x2 && Cta_tile::N == 64) { + write_row = (tidx & 0x20) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x40) / 2 + (tidx & 0x03); + write_col ^= (tidx & 0x1c); + + // SEQLEN == 128 and HIDDEN_SIZE_PER_HEAD == 128. + } else if (WARPS_2x1x2 && Cta_tile::N == 128) { + write_row = (tidx & 0x20) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x40) / 1 + (tidx & 0x1f); + + // SEQLEN == 256, 384, 512 and HIDDEN_SIZE_PER_HEAD == 16. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 16) { + write_row = (tidx & 0x10) / 16; + write_col = (tidx & 0x0c) * 2 + (tidx & 0xe3); + write_col ^= (tidx & 0x10) / 4; + + // SEQLEN == 256, 384, 512 and HIDDEN_SIZE_PER_HEAD == 32. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 32) { + write_row = (tidx & 0x18) / 8; + write_col = (tidx & 0x04) * 4 + (tidx & 0xe3); + write_col ^= (tidx & 0x18) / 2; + + // SEQLEN == 256, 384 and HIDDEN_SIZE_PER_HEAD == 64. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 64) { + write_row = (tidx & 0x1c) / 4; + write_col = (tidx & 0xff); + + // SEQLEN == 256, 384 and HIDDEN_SIZE_PER_HEAD == 128. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 128) { + write_row = (tidx & 0x1c) / 4; + write_col = (tidx & 0xe0) * 2 + (tidx & 0x1f); + + // SEQLEN == 256, 384 and HIDDEN_SIZE_PER_HEAD == 256. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 256) { + write_row = (tidx & 0x1c) / 4; + write_col = (tidx & 0xe0) * 4 + (tidx & 0x1f); + + // ANY SEQLEN and HIDDEN_SIZE_PER_HEAD == 16. + } else if (WARPS_4x1x1 && Cta_tile::N == 16) { + write_row = (tidx & 0xe0) / 8 + (tidx & 0x10) / 16; + write_col = (tidx & 0x0c) * 2 + (tidx & 0x03); + write_col ^= (tidx & 0x10) / 4; + + // ANY SEQLEN and HIDDEN_SIZE_PER_HEAD == 32. + } else if (WARPS_4x1x1 && Cta_tile::N == 32) { + write_row = (tidx & 0xe0) / 4 + (tidx & 0x18) / 8; + write_col = (tidx & 0x04) * 4 + (tidx & 0x03); + write_col ^= (tidx & 0x18) / 2; + + // ANY SEQLEN and HIDDEN_SIZE_PER_HEAD == 64/128. 
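// Illustrative sketch (not part of the patch): the bit masks in this chain are a warp/lane
// decode. For the WARPS_4x1x1, N == 64/128 branch just below, a host-side check that the
// formulas reduce to "warp * 16 + lane / 4" for the row and "lane" for the STS column:
#include <cassert>

int main() {
  for (int tidx = 0; tidx < 128; ++tidx) {
    int const warp = tidx / 32;
    int const lane = tidx % 32;
    int const write_row = (tidx & 0x60) / 2 + (tidx & 0x1c) / 4;
    int const write_col = (tidx & 0x1f);
    assert(write_row == warp * 16 + lane / 4);
    assert(write_col == lane);
  }
  return 0;
}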
+ } else if (WARPS_4x1x1 && (Cta_tile::N == 64 || Cta_tile::N == 128)) { + write_row = (tidx & 0x60) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x1f); + + // ANY SEQLEN and HIDDEN_SIZE_PER_HEAD == 256. + } else if (WARPS_4x1x1 && (Cta_tile::N == 256 || Cta_tile::N == 512)) { + write_row = (tidx & 0x60) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x1f); + + // GMMA: S=284/512 and HIDDEN_SIZE_PER_HEAD == 64. + } else if (WARPS_4x1x2 && Cta_tile::N == 64) { + write_row = (tidx & 0x60) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x80) / 4 + (tidx & 0x03); + write_col ^= (tidx & 0x1c); + + // GMMA: S=284/512 and HIDDEN_SIZE_PER_HEAD == 64. + } else if (WARPS_4x1x2 && Cta_tile::N == 32) { + write_row = (tidx & 0x60) / 4 + (tidx & 0x1c) / 8; + write_col = (tidx & 0x80) / 4 + (tidx & 0x04) * 4 + (tidx & 0x03); + write_col ^= (tidx & 0x18) / 2; + + // Not supported. + } else { + assert(false); + } + + // Assemble the write pointer. + smem_write_ = smem_ + write_row * BYTES_PER_ROW_WITH_PACKING * Cta_tile::WARPS_K + + write_col * BYTES_PER_STS; + + // The element read by each thread. + int read_row = tidx / THREADS_PER_ROW; + int read_col = tidx % THREADS_PER_ROW; + + // Is that thread active on the last LDS? + if (HAS_INCOMPLETE_LDS) { + is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < ROWS_PER_LOOP; + } + + // The XOR params. + int const XOR_MOD = 8 / ROW_PACKING; + + // Take the XOR pattern and the packing into account for the column. + read_col += read_row % ROW_PACKING * XOR_MOD; + read_row /= ROW_PACKING; + read_col ^= read_row % XOR_MOD; + + // Assemble the read pointer. + smem_read_ = smem_ + read_row * BYTES_PER_ROW_WITH_PACKING * Cta_tile::WARPS_K + + read_col * BYTES_PER_LDS; + } + + // Load the output fragments. + inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { + uint32_t local_smem_read_ = smem_read_; +#pragma unroll + for (int ii = 0; ii < LDS_PER_LOOP; ++ii) { + // Apply the XOR pattern if needed. (XOR 8 default) + if (ROWS_PER_LDS < 8) { + local_smem_read_ = (smem_read_ ^ ((ii * ROWS_PER_LDS) % 8 * BYTES_PER_LDS)); + } + + // Load the elements before the reduction (split-K). + uint4 tmp[Cta_tile::WARPS_K]; +#pragma unroll + for (int jj = 0; jj < Cta_tile::WARPS_K; ++jj) { + // Note: ROWS_PER_LDS does not take packing into account - hence BYTES_PER_ROW. + int imm = + ii * ROWS_PER_LDS * BYTES_PER_ROW * Cta_tile::WARPS_K + jj * BYTES_PER_ROW_WITH_PACKING; + + // Load... + if (!HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || is_active_for_last_lds_)) { + fmha::lds(tmp[jj], local_smem_read_ + imm); + } + } + + // Perform the reduction. + out[ii] = tmp[0]; +#pragma unroll + for (int jj = 1; jj < Cta_tile::WARPS_K; ++jj) { + out[ii] = fmha::add8(out[ii], tmp[jj]); + } + } + } + + // Store the accumulators. + template + inline __device__ void store_(Accumulators const (&acc)[M][N], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; + + Converter converter; +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + + // Store 1st column of the different MMAs. + // Skip N paddings + if (ni < Mma_tile::VALID_MMAS_N) { +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. 
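// Illustrative sketch (not part of the patch): each MMA covers 16 rows per warp and a
// thread's accumulators land in two of them, 8 rows apart, hence the "+ 0" / "+ 8"
// immediates below. Worked byte offsets, assuming M_PER_MMA = 16, a 128B row
// (N = 64, fp16 epilogue) and WARPS_K = 1:
constexpr int row_byte_offset(int mj, int row_in_mma, int m_per_mma = 16,
                              int bytes_per_row = 128, int warps_k = 1) {
  return (mj * m_per_mma + row_in_mma) * bytes_per_row * warps_k;
}

static_assert(row_byte_offset(0, 0) == 0, "mj = 0, upper row of the pair");
static_assert(row_byte_offset(0, 8) == 1024, "mj = 0, lower row: 8 rows * 128B further");
static_assert(row_byte_offset(1, 0) == 2048, "mj = 1 starts 16 rows * 128B further");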
+ int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW * Cta_tile::WARPS_K; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW * Cta_tile::WARPS_K; + + // The values (2 halves per register). + uint32_t h0 = converter.convert(acc[mi * MMAS_M_PER_LOOP + mj][ni], 0); + uint32_t h1 = converter.convert(acc[mi * MMAS_M_PER_LOOP + mj][ni], 1); + + // Store to shared memory. + fmha::sts(smem_write_ + row_0, h0); + fmha::sts(smem_write_ + row_1, h1); + } + } + + // Swizzle the write pointer using a XOR of 16B. + smem_write_ ^= 16; + + // Store 2nd column of the different MMAs. + // Skip N paddings + if (ni < Mma_tile::VALID_MMAS_N) { +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW * Cta_tile::WARPS_K; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW * Cta_tile::WARPS_K; + + // The values (2 halves per register). + uint32_t h2 = converter.convert(acc[mi * MMAS_M_PER_LOOP + mj][ni], 2); + uint32_t h3 = converter.convert(acc[mi * MMAS_M_PER_LOOP + mj][ni], 3); + + // Store to shared memory. + fmha::sts(smem_write_ + row_0, h2); + fmha::sts(smem_write_ + row_1, h3); + } + } + + // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B. + if (ROW_PACKING == 4) { + smem_write_ ^= 16; + } else if (ROW_PACKING == 2) { + smem_write_ ^= 3 * 16; + } else if (ROW_PACKING == 1) { + //         7 + //       /    \ + //      3      3 + //    /  \    /  \ + //   1    1  1    1 + static_assert(Mma_tile::MMAS_N <= 64, ""); + if (Mma_tile::MMAS_N >= 32 && ni % 16 == 15) { + smem_write_ ^= 63 * 16; + } else if (Mma_tile::MMAS_N >= 16 && ni % 8 == 7) { + smem_write_ ^= 31 * 16; + } else if (Mma_tile::MMAS_N >= 8 && ni % 4 == 3) { + smem_write_ ^= 15 * 16; + } else if (Mma_tile::MMAS_N >= 4 && ni % 2 == 1) { + smem_write_ ^= 7 * 16; + } else if (Mma_tile::MMAS_N >= 2) { + smem_write_ ^= 3 * 16; + } + } else { + assert(false); + } + } + } + + // The write pointer. + uint32_t smem_write_; + // The write pointer. + uint32_t smem_read_; + // Is the thread active for the last LDS of the series? + int is_active_for_last_lds_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Hmma_smem_tile_o { + // The traits class. + using Traits = fmha::Turing_hmma_fp16_traits; + // The base class. + using Base = Hmma_smem_tile_o; + + // The FP16 accumulators. + using Accumulators_fp16 = fmha::Fragment_accumulator; + // The FP32 accumulators. + using Accumulators_fp32 = fmha::Fragment_accumulator; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} + + // Store from FP16 accumulators. That's the default. + template + inline __device__ void store(Accumulators_fp16 const (&acc)[M][N], int mi) { + this->template store_(acc, mi); + } + + // Store from FP32 accumulators. Special trick for the Flash-attention kernel. + // Convert from fp32 to fp16 before STS + template + inline __device__ void store(Accumulators_fp32 const (&acc)[M][N], int mi) { + this->template store_(acc, mi); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Hmma_smem_tile_o { + // The traits class. + using Traits = fmha::Ampere_hmma_fp16_traits; + // The base class. + using Base = Hmma_smem_tile_o; + + // The FP16 accumulators. 
+ using Accumulators_fp16 = fmha::Fragment_accumulator; + // The FP32 accumulators. + using Accumulators_fp32 = fmha::Fragment_accumulator; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} + + // Store from FP16 accumulators. That's the default. + template + inline __device__ void store(Accumulators_fp16 const (&acc)[M][N], int mi) { + this->template store_(acc, mi); + } + + // Store from FP32 accumulators. Special trick for the Flash-attention kernel. + // Convert from fp32 to fp16 before STS + template + inline __device__ void store(Accumulators_fp32 const (&acc)[M][N], int mi) { + this->template store_(acc, mi); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Hmma_smem_tile_o { + // The traits class. + using Traits = fmha::Ampere_hmma_bf16_bf16_traits; + // The base class. + using Base = Hmma_smem_tile_o; + + // The FP32 accumulators (only FP32 acc is supported for BF16 MMA). + using Accumulators_bf16 = fmha::Fragment_accumulator; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} + + // Store from FP32 accumulators. Special trick for the Flash-attention kernel. + // Convert from fp32 to bf16 before STS + template + inline __device__ void store(Accumulators_bf16 const (&acc)[M][N], int mi) { + this->template store_(acc, mi); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Hmma_smem_tile_o { + // The traits class. + using Traits = fmha::Ampere_hmma_fp32_traits; + // The base class. + using Base = Hmma_smem_tile_o; + // The MMA tile. + using Mma_tile = typename Base::Mma_tile; + // The accumulators. + using Accumulator = typename Base::Accumulator; + + // The size of each + enum { BYTES_PER_ELEMENT = Base::BYTES_PER_ELEMENT }; + + // The size of each row in shared memory. + enum { BYTES_PER_ROW = Base::BYTES_PER_ROW * Cta_tile::WARPS_K }; + + // The size of each row in shared memory. + enum { BYTES_PER_LDS = Base::BYTES_PER_LDS }; + + // The number of threads (to produce 16B per LDS). + enum { THREADS_PER_ROW = Base::THREADS_PER_ROW }; + + // The number of outer loops. + enum { LOOPS = Base::LOOPS }; + + // The number of rows loaded per LDS. + enum { ROWS_PER_LDS = Base::ROWS_PER_LDS }; + + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_LDS = Base::HAS_INCOMPLETE_LDS }; + + // The total number of LDS per loop. + enum { LDS_PER_LOOP = Base::LDS_PER_LOOP }; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) { + // Get a 32-bit value for the shared memory address. + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + + // The element read by each thread. + int read_row = tidx / THREADS_PER_ROW; + int read_col = tidx % THREADS_PER_ROW; + + // Take the XOR pattern into account for the column. + read_col ^= (read_row & 0x7) * 2; + + // Assemble the read pointer. + this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + + // Is that thread active on the last LDS? + if (HAS_INCOMPLETE_LDS) { + this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M; + } + } + + // Load the output fragments. + inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { +#pragma unroll + for (int ii = 0; ii < LDS_PER_LOOP; ++ii) { + // Load the elements before the reduction (split-K). 
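// Illustrative sketch (not part of the patch): the loop below is the split-K reduction.
// Each of the WARPS_K warps stored a partial output copy side by side in smem, and every
// 16B vector read here is summed across those copies. A scalar host-side analogue,
// assuming the partial tiles sit `stride` floats apart:
#include <cassert>

float reduce_split_k(float const* partials, int warps_k, int stride, int idx) {
  float sum = partials[idx];
  for (int jj = 1; jj < warps_k; ++jj) {
    sum += partials[jj * stride + idx];
  }
  return sum;
}

int main() {
  float tiles[2 * 4] = {1.f, 2.f, 3.f, 4.f,       // partial tile written by warp-k 0
                        10.f, 20.f, 30.f, 40.f};  // partial tile written by warp-k 1
  assert(reduce_split_k(tiles, /*warps_k=*/2, /*stride=*/4, /*idx=*/2) == 33.f);
  return 0;
}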
+ uint4 tmp[Cta_tile::WARPS_K]; +#pragma unroll + for (int jj = 0; jj < Cta_tile::WARPS_K; ++jj) { + int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT; + int is_valid = ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_; + if (!HAS_INCOMPLETE_LDS || is_valid) { + fmha::lds(tmp[jj], this->smem_read_ + imm); + } + } + + // Perform the reduction. + out[ii] = tmp[0]; +#pragma unroll + for (int jj = 1; jj < Cta_tile::WARPS_K; ++jj) { + out[ii] = fmha::fadd4(out[ii], tmp[jj]); + } + } + } + + // Store the accumulators. + template + inline __device__ void store(Accumulator const (&acc)[M][N], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + + // Store 1st column of the different MMAs. + if (ni < Mma_tile::VALID_MMAS_N) { +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; + + // Pack vectors. + uint2 tmp0; + tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0); + tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1); + + uint2 tmp1; + tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2); + tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3); + + // Store. + fmha::sts(this->smem_write_ + row_0, tmp0); + fmha::sts(this->smem_write_ + row_1, tmp1); + } + } + + // Swizzle the write pointer using a XOR of 16B. + this->smem_write_ ^= 32; + + // Store 2nd column of the different MMAs. + if (ni < Mma_tile::VALID_MMAS_N) { +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; + + uint2 tmp0, tmp1; + tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4); + tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5); + + tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6); + tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7); + + // Store. + fmha::sts(this->smem_write_ + row_0, tmp0); + fmha::sts(this->smem_write_ + row_1, tmp1); + } + } + + // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B. + static_assert(Mma_tile::MMAS_N <= 16, ""); + if (Mma_tile::MMAS_N >= 16 && (ni & 7) == 7) { + this->smem_write_ ^= 31 * 32; + } else if (Mma_tile::MMAS_N >= 8 && (ni & 3) == 3) { + this->smem_write_ ^= 15 * 32; + } else if (Mma_tile::MMAS_N >= 4 && (ni & 1) == 1) { + this->smem_write_ ^= 7 * 32; + } else if ((ni & 1) == 0) { + this->smem_write_ ^= 3 * 32; + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Hmma_smem_tile_o { + // The traits class. + using Traits = fmha::Ampere_hmma_bf16_traits; + // The base class. + using Base = Hmma_smem_tile_o; + // The MMA tile. + using Mma_tile = typename Base::Mma_tile; + // The accumulators. + using Accumulator = typename Base::Accumulator; + + // The size of each element. + enum { BYTES_PER_ELEMENT = Base::BYTES_PER_ELEMENT }; + + // The size of each row in shared memory. + enum { BYTES_PER_ROW = Base::BYTES_PER_ROW * Cta_tile::WARPS_K }; + + // The size of each row in shared memory. + enum { BYTES_PER_LDS = Base::BYTES_PER_LDS }; + + // The number of threads (to produce 16B per LDS). 
+ enum { THREADS_PER_ROW = Base::THREADS_PER_ROW }; + + // The number of outer loops. + enum { LOOPS = Base::LOOPS }; + + // The number of rows loaded per LDS. + enum { ROWS_PER_LDS = Base::ROWS_PER_LDS }; + + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_LDS = Base::HAS_INCOMPLETE_LDS }; + + // The total number of LDS per loop. + enum { LDS_PER_LOOP = Base::LDS_PER_LOOP }; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) { + // Get a 32-bit value for the shared memory address. + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + + // The element read by each thread. + int read_row = tidx / THREADS_PER_ROW; + int read_col = tidx % THREADS_PER_ROW; + + // Take the XOR pattern into account for the column. + read_col ^= (read_row & 0x7) * 2; + + // Assemble the read pointer. + this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + + // Is that thread active on the last LDS? + if (HAS_INCOMPLETE_LDS) { + this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M; + } + } + + // Load the output fragments. + inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { +#pragma unroll + for (int ii = 0; ii < LDS_PER_LOOP; ++ii) { + // Load the elements before the reduction (split-K). + uint4 tmp[Cta_tile::WARPS_K]; +#pragma unroll + for (int jj = 0; jj < Cta_tile::WARPS_K; ++jj) { + int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT; + int is_valid = ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_; + if (!HAS_INCOMPLETE_LDS || is_valid) { + fmha::lds(tmp[jj], this->smem_read_ + imm); + } + } + + // Perform the reduction. + out[ii] = tmp[0]; +#pragma unroll + for (int jj = 1; jj < Cta_tile::WARPS_K; ++jj) { + out[ii] = fmha::fadd4(out[ii], tmp[jj]); + } + } + } + + // Store the accumulators. + template + inline __device__ void store(Accumulator const (&acc)[M][N], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + + // Store 1st column of the different MMAs. + if (ni < Mma_tile::VALID_MMAS_N) { +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; + + // Pack vectors. + uint2 tmp0; + tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0); + tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1); + + uint2 tmp1; + tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2); + tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3); + + // Store. + fmha::sts(this->smem_write_ + row_0, tmp0); + fmha::sts(this->smem_write_ + row_1, tmp1); + } + } + + // Swizzle the write pointer using a XOR of 16B. + this->smem_write_ ^= 32; + + // Store 2nd column of the different MMAs. + if (ni < Mma_tile::VALID_MMAS_N) { +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; + + uint2 tmp0, tmp1; + tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4); + tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5); + + tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6); + tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7); + + // Store. 
+ fmha::sts(this->smem_write_ + row_0, tmp0); + fmha::sts(this->smem_write_ + row_1, tmp1); + } + } + + // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B. + static_assert(Mma_tile::MMAS_N <= 16, ""); + if ((ni & 1) == 0) { + this->smem_write_ ^= 3 * 32; + } else if (Mma_tile::MMAS_N >= 16 && (ni & 7) == 7) { + this->smem_write_ ^= 31 * 32; + } else if (Mma_tile::MMAS_N >= 8 && (ni & 3) == 3) { + this->smem_write_ ^= 15 * 32; + } else if (Mma_tile::MMAS_N >= 4 && (ni & 1) == 1) { + this->smem_write_ ^= 7 * 32; + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// I M M A +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// each thread holds 8 accumulator registers per 16x16 MMA, representing a 2x4 tile +template +struct Regs_to_rows { + template + static inline __device__ void extract(Acc const& acc, uint4& row0, uint4& row1) { + // Volta/Turing: row-major + uint32_t tmp_00 = acc.reg(0); + uint32_t tmp_01 = acc.reg(2); + uint32_t tmp_02 = acc.reg(1); + uint32_t tmp_03 = acc.reg(3); + uint32_t tmp_10 = acc.reg(4); + uint32_t tmp_11 = acc.reg(6); + uint32_t tmp_12 = acc.reg(5); + uint32_t tmp_13 = acc.reg(7); + + row0.x = tmp_00; + row0.y = tmp_01; + row0.z = tmp_02; + row0.w = tmp_03; + + row1.x = tmp_10; + row1.y = tmp_11; + row1.z = tmp_12; + row1.w = tmp_13; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Regs_to_rows_8bit { + template + static inline __device__ void extract(Acc const& acc, uint4& row0, uint4& row1) { + // Ampere: col-major + uint32_t tmp_00 = acc.reg(0); + uint32_t tmp_01 = acc.reg(4); + uint32_t tmp_02 = acc.reg(1); + uint32_t tmp_03 = acc.reg(5); + uint32_t tmp_10 = acc.reg(2); + uint32_t tmp_11 = acc.reg(6); + uint32_t tmp_12 = acc.reg(3); + uint32_t tmp_13 = acc.reg(7); + + row0.x = tmp_00; + row0.y = tmp_01; + row0.z = tmp_02; + row0.w = tmp_03; + + row1.x = tmp_10; + row1.y = tmp_11; + row1.z = tmp_12; + row1.w = tmp_13; + } +}; + +template <> +struct Regs_to_rows : public Regs_to_rows_8bit {}; + +template <> +struct Regs_to_rows : public Regs_to_rows_8bit {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Regs_to_rows { + template + static inline __device__ void extract(Acc const& acc, uint2& row0, uint2& row1) { + uint16_t* row0_ptr = reinterpret_cast(&row0); + uint16_t* row1_ptr = reinterpret_cast(&row1); + row0_ptr[0] = acc.u16(0); + row0_ptr[1] = acc.u16(4); + row0_ptr[2] = acc.u16(1); + row0_ptr[3] = acc.u16(5); + + row1_ptr[0] = acc.u16(2); + row1_ptr[1] = acc.u16(6); + row1_ptr[2] = acc.u16(3); + row1_ptr[3] = acc.u16(7); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void add4(uint4& dst, uint4 const& src) { + reinterpret_cast(dst.x) += reinterpret_cast(src.x); + reinterpret_cast(dst.y) += reinterpret_cast(src.y); + reinterpret_cast(dst.z) += reinterpret_cast(src.z); + reinterpret_cast(dst.w) += reinterpret_cast(src.w); +} + +template +inline __device__ void add_vec(uint4& dst, uint4 const& src) { + add4(dst, src); +} + +template <> +inline __device__ void add_vec(uint4& dst, uint4 const& src) { + dst = fmha::hadd8(dst, src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// The base 
class for 32-bit/16-bit accumulator types of imma/qmma. +// TODO Can we port Ampere hmma fp32 to this? +template +struct Smem_tile_o_base_8bit_mma { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + + // The size of each element. + enum { BYTES_PER_ELEMENT = sizeof(typename Traits::Accumulator_type) }; + + // The amount of bytes per row (without packing or split-k). + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // The size of each STS. + enum { BYTES_PER_STS = BYTES_PER_ELEMENT * 4 }; + + // The STS Packed Data Type + using Sts_packed_type = typename Uint_from_size_in_bytes::Type; + + // The size of each LDS. + enum { BYTES_PER_LDS = 16 }; + + // The number of threads to store a "row" of the matrix. We force it to 16 for SEQLEN=384. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS }; + + // The number of rows loaded per LDS. + enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The STS bytes for one quad of threads + enum { BYTES_PER_STS_PER_QUAD = BYTES_PER_STS * 4 }; + + // The xor factor per LDS + // (4 consecutive threads do 64B swizzle for 16B per sts, 32B swizzle for 8B per sts) + enum { XOR_FACTOR = fmha::Div_up::VALUE }; + + // The smem offset in bytes per MMA_N (2 squad threads) + enum { BYTES_OFFSET_PER_MMA_N = BYTES_PER_STS * 8 }; + + // The number of "rows" to process in total. + enum { ROWS = Cta_tile::M }; + + // We want at least one output per thread (if possible). + enum { ROWS_PER_LOOP_ = ROWS <= 64 ? ROWS : (int)Min::VALUE }; + + // We also want to have "complete" MMAs. + enum { ROWS_PER_LOOP = Max::VALUE }; + + // The number of outer loops. + enum { LOOPS = fmha::Div_up::VALUE }; + + // Make sure it matches our expectations. + static_assert(ROWS_PER_LOOP >= (int)Mma_tile::M_PER_MMA_PER_CTA, ""); + + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0 }; + + // The total number of LDS per loop. + enum { LDS_PER_LOOP = fmha::Div_up::VALUE }; + + // The amount of shared memory. + enum { BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW * Cta_tile::WARPS_K }; + + // The amount of row packing to make sure we have at least 128B per smem row (without split-k). + enum { ROW_PACKING = Max<1, 128 / BYTES_PER_ROW>::VALUE }; + + // Make sure our row packing is correct + static_assert(ROWS_PER_LOOP % ROW_PACKING == 0, ""); + + // The amount of shared memory per row after packing. + enum { BYTES_PER_ROW_WITH_PACKING = BYTES_PER_ROW * ROW_PACKING }; + + // Make sure we have at least 128B per row after packing. + static_assert(BYTES_PER_ROW_WITH_PACKING >= 128, ""); + + // The number of threads per row after packing. + enum { THREADS_PER_ROW_WITH_PACKING = THREADS_PER_ROW * ROW_PACKING }; + + // Make sure we have at least 8 threads per row after packing. + static_assert(THREADS_PER_ROW_WITH_PACKING >= 8, ""); + + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + static_assert(WARPS_K > 1 || std::is_same::value, + "Kernel misconfigured. No split-k needed."); + + // Determine the config. 
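// Illustrative sketch (not part of the patch): the STS width defined above tracks the
// accumulator type. Each thread stores 4 accumulator lanes per STS, so 32-bit accumulators
// give 16B stores (uint4, 64B per quad) and 16-bit accumulators give 8B stores
// (uint2, 32B per quad). Host-side check:
constexpr int bytes_per_sts(int bytes_per_element) { return bytes_per_element * 4; }

static_assert(bytes_per_sts(4) == 16, "int32/fp32 accumulators: 16B STS");
static_assert(bytes_per_sts(2) == 8, "fp16 accumulators: 8B STS");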
+ enum { WARPS_2x1x2 = WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_4x1x2 = WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_1x1x8 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 8 }; + + enum { WARPS_1x1x4 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 4 }; + + enum { WARPS_4x1x1 = WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1 }; + + // Ctor. + inline __device__ Smem_tile_o_base_8bit_mma(void* smem, int tidx) { + // Get a 32-bit value for the shared memory address. + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + + // The row/col written by the thread. + int write_row, write_col; + + // SEQLEN == 128 and HIDDEN_SIZE_PER_HEAD == 16. + if (WARPS_2x1x2 && Cta_tile::N == 16) { + write_row = (tidx & 0x20) / 4 + (tidx & 0x1e) / 8; + write_col = (tidx & 0x40) / 8 + (tidx & 0x07); + + // SEQLEN == 128 and HIDDEN_SIZE_PER_HEAD == 32. + } else if (WARPS_2x1x2 && Cta_tile::N == 32) { + write_row = (tidx & 0x20) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x40) / 8 + (tidx & 0x07); + + // SEQLEN == 128 and HIDDEN_SIZE_PER_HEAD == 64. + } else if (WARPS_2x1x2 && Cta_tile::N == 64) { + write_row = (tidx & 0x20) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x40) / 4 + (tidx & 0x07); + + // SEQLEN == 256, 384, 512 and HIDDEN_SIZE_PER_HEAD == 16. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 16) { + write_row = (tidx & 0x18) / 8; + write_col = (tidx & 0xe0) / 4 + (tidx & 0x07); + + // SEQLEN == 256, 384, 512 and HIDDEN_SIZE_PER_HEAD == 32. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 32) { + write_row = (tidx & 0x1c) / 4; + write_col = (tidx & 0xe0) / 4 + (tidx & 0x07); + + // SEQLEN == 256, 384 and HIDDEN_SIZE_PER_HEAD == 64. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 64) { + write_row = (tidx & 0x1c) / 4; + write_col = (tidx & 0xe0) / 2 + (tidx & 0x07); + + // GMMA: HIDDEN_SIZE_PER_HEAD == 64. + } else if (WARPS_4x1x2 && Cta_tile::N == 64) { + write_row = (tidx & 0x60) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x80) / 8 + (tidx & 0x07); + + // Ada e4m3_fp32 + } else if (WARPS_4x1x1) { + write_row = (tidx & 0x60) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x80) / 8 + (tidx & 0x07); + + // Not supported. + } else { + assert(false); + } + + // Assemble the write pointer. + smem_write_ = smem_ + write_row * BYTES_PER_ROW_WITH_PACKING * Cta_tile::WARPS_K + + write_col * BYTES_PER_STS; + + // The element read by each thread. + int read_row = tidx / THREADS_PER_ROW; + int read_col = tidx % THREADS_PER_ROW; + + // Is that thread active on the last LDS? + if (HAS_INCOMPLETE_LDS) { + is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < ROWS_PER_LOOP; + } + + // The XOR params. + constexpr int XOR_MOD = 2 / ROW_PACKING; + + // Take the XOR pattern and the packing into account for the column. + read_col += read_row % ROW_PACKING * XOR_FACTOR; + read_row /= ROW_PACKING; + read_col ^= (read_row % XOR_MOD) * XOR_FACTOR; + + // Assemble the read pointer. + smem_read_ = smem_ + read_row * BYTES_PER_ROW_WITH_PACKING * Cta_tile::WARPS_K + + read_col * BYTES_PER_LDS; + } + + // Load the output fragments. + inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { +#pragma unroll + for (int ii = 0; ii < LDS_PER_LOOP; ++ii) { + // Load the elements before the reduction (split-K). + uint4 tmp[Cta_tile::WARPS_K]; +#pragma unroll + for (int jj = 0; jj < Cta_tile::WARPS_K; ++jj) { + // Note: ROWS_PER_LDS does not take packing into account - hence BYTES_PER_ROW. 
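// Illustrative sketch (not part of the patch): the immediate below combines two strides.
// `ii` walks whole LDS row groups (times WARPS_K, because each packed smem row holds all
// WARPS_K partial copies side by side), while `jj` hops to the next split-K copy inside
// the same packed row. Worked numbers assuming N = 64 with 4B accumulators (256B rows,
// so no packing), WARPS_K = 2 and a 256-thread CTA; the _X names are illustrative only:
constexpr int BYTES_PER_ROW_X = 64 * 4;                            // 256B, already >= 128B
constexpr int BYTES_PER_ROW_WITH_PACKING_X = BYTES_PER_ROW_X * 1;  // ROW_PACKING == 1
constexpr int THREADS_PER_ROW_X = BYTES_PER_ROW_X / 16;            // 16 threads cover one row
constexpr int ROWS_PER_LDS_X = 256 / THREADS_PER_ROW_X;            // 16 rows per LDS
constexpr int WARPS_K_X = 2;

// One step of ii advances 16 rows, each row holding WARPS_K copies of 256B.
static_assert(ROWS_PER_LDS_X * BYTES_PER_ROW_X * WARPS_K_X == 8192, "ii stride");
// One step of jj moves to the next split-K copy within the same packed row.
static_assert(1 * BYTES_PER_ROW_WITH_PACKING_X == 256, "jj stride");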
+ int imm = + ii * ROWS_PER_LDS * BYTES_PER_ROW * Cta_tile::WARPS_K + jj * BYTES_PER_ROW_WITH_PACKING; + + // Load... + if (!HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || is_active_for_last_lds_)) { + fmha::lds(tmp[jj], smem_read_ + imm); + } + } + +// Perform the reduction. +#pragma unroll + for (int jj = 1; jj < Cta_tile::WARPS_K; ++jj) { + add_vec(tmp[0], tmp[jj]); + } + + // Write to out. + out[ii] = tmp[0]; + } + } + + // Store the accumulators. + template + inline __device__ void store(Accumulator const (&acc)[M][N], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; + + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + Sts_packed_type row_0, row_1; + Regs_to_rows::extract(acc[mi * MMAS_M_PER_LOOP + mj][ni], row_0, row_1); + + /* + (32bit acc) Each thread of a quad writes 16B per STS -> 64B per store. + Account for 2 -> 128B. + (16bit acc) Each thread of a quad writes 8B per STS -> 32B per store. + Account for 2 -> 64B. + */ + int imm_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW * Cta_tile::WARPS_K + + (ni / 2) * BYTES_OFFSET_PER_MMA_N; + int imm_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW * Cta_tile::WARPS_K + + (ni / 2) * BYTES_OFFSET_PER_MMA_N; + + // Store the elements. + fmha::sts(this->smem_write_ + imm_0, row_0); + fmha::sts(this->smem_write_ + imm_1, row_1); + } + // (32bit acc) Each thread of a quad writes 16B per STS -> 64B per store. + // (16bit acc) Each thread of a quad writes 8B per STS -> 32B per store. + if (Mma_tile::MMAS_N == 1) { + // Noop. + } else if (Mma_tile::MMAS_N % 2 == 0) { + this->smem_write_ ^= BYTES_PER_STS_PER_QUAD; + } else { + assert(false && "Unsupported"); + } + } + } + + // The write pointer. + uint32_t smem_write_; + // The write pointer. + uint32_t smem_read_; + // Is the thread active for the last LDS of the series? + int is_active_for_last_lds_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Smem_tile_o_base_8bit_mma { + // The traits class. + using Traits = fmha::Volta_imma_int8_int32_traits; + // The base class. + using Base = Smem_tile_o_base_8bit_mma; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Smem_tile_o_base_8bit_mma { + // The traits class. + using Traits = fmha::Turing_imma_int8_int32_traits; + // The base class. + using Base = Smem_tile_o_base_8bit_mma; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Smem_tile_o_base_8bit_mma { + // The traits class. + using Traits = fmha::Ampere_imma_int8_int32_traits; + // The base class. + using Base = Smem_tile_o_base_8bit_mma; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Smem_tile_o_base_8bit_mma { + // The traits class. + using Traits = fmha::Ada_qmma_e4m3_fp32_traits; + // The base class. 
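// --------------------------------------------------------------------------------------------
// Illustrative sketch (editor's aside, not part of the patch): the split-K reduction performed
// by load() above. Each of the WARPS_K warps has written its partial O tile into its own slice
// of the packed smem row; the reader fetches one 16B chunk per slice and sums them lane-wise.
// Host-side model for INT32 accumulators (add_vec would do the same per 32-bit lane on the GPU):
#include <array>
#include <cassert>

namespace split_k_sketch {

constexpr int kWarpsK = 2;  // assumed split-K factor
constexpr int kLanes = 4;   // a uint4 holds four 32-bit accumulators

using Chunk = std::array<int, kLanes>;

inline Chunk reduce_slices(std::array<Chunk, kWarpsK> const& slices) {
  Chunk out = slices[0];
  for (int k = 1; k < kWarpsK; ++k) {  // mirrors the `add_vec(tmp[0], tmp[jj])` loop
    for (int l = 0; l < kLanes; ++l) out[l] += slices[k][l];
  }
  return out;
}

inline void self_test() {
  std::array<Chunk, kWarpsK> const slices = {{{1, 2, 3, 4}, {10, 20, 30, 40}}};
  Chunk const sum = reduce_slices(slices);
  assert(sum[0] == 11 && sum[3] == 44);
}

}  // namespace split_k_sketch
// --------------------------------------------------------------------------------------------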
+ using Base = Smem_tile_o_base_8bit_mma; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Smem_tile_o_base_8bit_mma { + // The traits class. + using Traits = fmha::Ada_qmma_e4m3_fp16_traits; + // The base class. + using Base = Smem_tile_o_base_8bit_mma; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o_interleaved { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + + enum { VEC = 32 }; + + enum { NUM_SLICES = Cta_tile::N / VEC }; + + static_assert(NUM_SLICES == 1 || NUM_SLICES == 2, ""); + + enum { BYTES_PER_ELEMENT = 4 }; + + enum { BYTES_PER_STS = 16 }; + + enum { BYTES_PER_LDS = 16 }; + + enum { ELTS_PER_STS = BYTES_PER_STS / BYTES_PER_ELEMENT }; + + static_assert(VEC * BYTES_PER_ELEMENT == 128, ""); + + enum { BYTES_PER_ROW = Cta_tile::WARPS_K * VEC * BYTES_PER_ELEMENT }; + + // Each row only stores one slice. The other slice starts this many rows below + enum { ROWS_PER_SLICE = Cta_tile::WARPS_M * 16 }; + + enum { TOTAL_ROWS = NUM_SLICES * ROWS_PER_SLICE }; + + enum { BYTES_PER_TILE = BYTES_PER_ROW * TOTAL_ROWS }; + + // LDS + enum { THREADS_PER_ROW = 8 }; + + enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + enum { LDS_PER_LOOP = TOTAL_ROWS / ROWS_PER_LDS }; + + // Ctor. + inline __device__ Smem_tile_o_interleaved(void* smem, int tidx) { + smem_ = __nvvm_get_smem_pointer(smem); + + constexpr int WARPS_M = Cta_tile::WARPS_M; + constexpr int WARPS_N = Cta_tile::WARPS_N; + constexpr int WARPS_K = Cta_tile::WARPS_K; + + // Warp order (fastest to slowest): m => n => k + // 2x2: 2,2,1 then 2,1,2: mask_m = 0x20, mask_k = 0x40, div_m = 32, div_k = 64 + // 1x4: 1,4,1 then 1,1,4: mask_m = 0x00, mask_k = 0x60, div_m = X, div_k = 32 + // 1x8: 1,8,1 then 1,1,8: mask_m = 0x00, mask_k = 0xe0, div_m = X, div_k = 32 + static_assert(WARPS_N == 1, ""); + + // A thread holds 4 elts of 4B. One slice of 32 elts has 128B. + // Two MMAs in N constitute one slice + + // the slice offset that depends on ni and has to be added later + static_assert(VEC / ELTS_PER_STS == 8, ""); // 8 columns of 4 elements + if (WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2) { + write_row = (tidx & 0x1c) / 4 + (tidx & 0x20) / 2; // warp_m * 16 rows + write_col = (tidx & 0x03) + (tidx & 0x40) / 8; // warp_k * VEC / ELTS_PER_STS + } else { + write_row = (tidx & 0x1c) / 4; + write_col = (tidx & 0x03) + (tidx & 0xe0) / 4; // warp_k * VEC / ELTS_PER_STS + } + write_col ^= (write_row & 0x01) * 4; // left or right 64B + + // this->smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; + + int read_row = tidx / THREADS_PER_ROW; + int read_col = tidx % THREADS_PER_ROW; + read_col ^= (read_row & 0x01) * 4; + this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + // Store the accumulators. 
+ template + inline __device__ void store(Accumulator const (&acc)[M][N], int mi) { +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + int const slice = ni / NUM_SLICES; + int col = write_col ^ ((ni & 1) * 4); + + uint32_t smem_write_ = smem_ + write_row * BYTES_PER_ROW + col * BYTES_PER_STS; + + // Extract the elements. + uint4 row_0, row_1; + + Regs_to_rows::extract(acc[mi][ni], row_0, row_1); + + // Each thread of a quad writes 16B per STS -> 64B per store. Account for + // 2 -> 128B. + int imm_0 = (slice * ROWS_PER_SLICE + 0) * BYTES_PER_ROW; + int imm_1 = (slice * ROWS_PER_SLICE + 8) * BYTES_PER_ROW; + + // Store the elements. + fmha::sts(smem_write_ + imm_0, row_0); + fmha::sts(smem_write_ + imm_1, row_1); + } + } + + // Load the output fragments. + inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { +#pragma unroll + for (int ii = 0; ii < LDS_PER_LOOP; ++ii) { + // Load the elements before the reduction (split-K). + uint4 tmp[Cta_tile::WARPS_K]; +#pragma unroll + for (int jj = 0; jj < Cta_tile::WARPS_K; ++jj) { + int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * VEC * BYTES_PER_ELEMENT; + fmha::lds(tmp[jj], smem_read_ + imm); + } + +// Perform the reduction. +#pragma unroll + for (int jj = 1; jj < Cta_tile::WARPS_K; ++jj) { + add4(tmp[0], tmp[jj]); + } + + // Write to out. + out[ii] = tmp[0]; + } + } + + int write_row; + int write_col; + uint32_t smem_write_; + uint32_t smem_read_; + uint32_t smem_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/smem_tile_qkv.h b/csrc/fmha_v2/fmha/smem_tile_qkv.h new file mode 100644 index 0000000000..32caaadb3a --- /dev/null +++ b/csrc/fmha_v2/fmha/smem_tile_qkv.h @@ -0,0 +1,592 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qkv_interleaved + : public fmha::Smem_tile_without_skews { + // The traits class. + using Traits = Traits_; + // The base class. + using Base = fmha::Smem_tile_without_skews; + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The fragment. + + // The size of a single LDS in bytes. 
+ enum { BYTES_PER_LDS = 16 }; + + enum { ROWS_PER_WARP = Cta_tile::THREADS_PER_WARP / Base::THREADS_PER_ROW }; + + using Fragment_a = fmha::Fragment_a; + using Fragment_b = fmha::Fragment_b; + + inline __device__ Smem_tile_qkv_interleaved(char* smem, int tidx) : Base(smem, tidx) {} + + uint32_t offset; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_a_base : public Smem_tile_qkv_interleaved { + using Base = Smem_tile_qkv_interleaved; + + static_assert(Base::THREADS_PER_ROW == 128 / 16, ""); + + enum { SMEM_ROWS_PER_WARP = Base::ROWS_PER_WARP }; + + static_assert(SMEM_ROWS_PER_WARP == 4, ""); + + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename Base::Fragment_a; + + inline __device__ Smem_tile_qk_interleaved_a_base(char* smem, int tidx) : Base(smem, tidx) { + static_assert(Cta_tile::WARPS_K == 1, ""); + static_assert(Cta_tile::WARPS_M == 1 || Cta_tile::WARPS_M == 2, ""); + static_assert(Cta_tile::WARPS_N == 2 || Cta_tile::WARPS_N == 4, ""); + + constexpr int WARPS_M = Cta_tile::WARPS_M; + constexpr int WARPS_N = Cta_tile::WARPS_N; + constexpr int WARPS_K = Cta_tile::WARPS_K; + + constexpr int WARP_MASK_M = fmha::Warp_masks::M; + constexpr int WARP_DIV_M = 1 * 1 * Cta_tile::THREADS_PER_WARP; + + int const warp_m = (tidx & WARP_MASK_M) / WARP_DIV_M; + + /* Read address layout for ldsm: + * [ 0 16 1 17 2 18 3 19] + * [20 4 21 5 22 6 23 7] + * [ 8 24 9 25 10 26 11 27] + * [28 12 29 13 30 14 31 15] + */ + int read_row = (tidx & 0x04) / 4 + (tidx & 0x08) / 4 + warp_m * SMEM_ROWS_PER_WARP; + int read_col = (tidx & 0x03) * 2 + (tidx & 0x10) / 16; + read_col ^= (read_row & 0x01); + + this->offset = read_row * Base::BYTES_PER_ROW + read_col * Base::BYTES_PER_LDS; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_a {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_a + : public Smem_tile_qk_interleaved_a_base { + using Traits = fmha::Volta_imma_int8_int32_traits; + using Base = Smem_tile_qk_interleaved_a_base; + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename Base::Fragment; + + inline __device__ Smem_tile_qk_interleaved_a(char* smem, int tidx) : Base(smem, tidx) {} + + inline __device__ void load(Fragment (&frag)[Mma_tile::MMAS_M], int ki) { + int slice = ki / 2; + +#pragma unroll + for (int mi = 0; mi < Mma_tile::MMAS_M; mi++) { + // the data for the second slice sits below the first slice + uint32_t read_ptr = this->smem_ + this->offset + slice * Base::ROWS * Base::BYTES_PER_ROW / 2; + uint2 data; + ldsm_with_lds( + data, read_ptr + mi * Cta_tile::WARPS_M * Base::SMEM_ROWS_PER_WARP * Base::BYTES_PER_ROW); + static_assert(Fragment::NUM_REGS == 2, ""); + frag[mi].reg(0) = data.x; + frag[mi].reg(1) = data.y; + } + + this->offset ^= 16; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_a + : public Smem_tile_qk_interleaved_a_base { + using Traits = fmha::Turing_imma_int8_int32_traits; + using Base = Smem_tile_qk_interleaved_a_base; + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename Base::Fragment; + + inline __device__ Smem_tile_qk_interleaved_a(char* smem, int tidx) : Base(smem, tidx) {} + + inline __device__ void 
load(Fragment (&frag)[Mma_tile::MMAS_M], int ki) { + int slice = ki / 2; + +#pragma unroll + for (int mi = 0; mi < Mma_tile::MMAS_M; mi++) { + // the data for the second slice sits below the first slice + uint32_t read_ptr = this->smem_ + this->offset + slice * Base::ROWS * Base::BYTES_PER_ROW / 2; + uint2 data; + fmha::ldsm( + data, read_ptr + mi * Cta_tile::WARPS_M * Base::SMEM_ROWS_PER_WARP * Base::BYTES_PER_ROW); + static_assert(Fragment::NUM_REGS == 2, ""); + frag[mi].reg(0) = data.x; + frag[mi].reg(1) = data.y; + } + + this->offset ^= 16; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_a + : public Smem_tile_qk_interleaved_a_base { + using Traits = fmha::Ampere_imma_int8_int32_traits; + using Base = Smem_tile_qk_interleaved_a_base; + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename Base::Fragment; + + inline __device__ Smem_tile_qk_interleaved_a(char* smem, int tidx) : Base(smem, tidx) {} + + inline __device__ void load(Fragment (&frag)[Mma_tile::MMAS_M], int ki) { +#pragma unroll + for (int mi = 0; mi < Mma_tile::MMAS_M; mi++) { + // the data for the second slice sits below the first slice + uint32_t read_ptr = this->smem_ + this->offset + ki * Base::ROWS * Base::BYTES_PER_ROW / 2; + uint4 data; + fmha::ldsm( + data, read_ptr + mi * Cta_tile::WARPS_M * Base::SMEM_ROWS_PER_WARP * Base::BYTES_PER_ROW); + static_assert(Fragment ::NUM_REGS == 4, ""); + frag[mi].reg(0) = data.x; + frag[mi].reg(1) = data.y; + frag[mi].reg(2) = data.z; + frag[mi].reg(3) = data.w; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_b_base : public Smem_tile_qkv_interleaved { + using Base = Smem_tile_qkv_interleaved; + + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename Base::Fragment_b; + + inline __device__ Smem_tile_qk_interleaved_b_base(char* smem, int tidx) : Base(smem, tidx) { + constexpr int WARPS_M = Cta_tile::WARPS_M; + constexpr int WARPS_N = Cta_tile::WARPS_N; + constexpr int WARPS_K = Cta_tile::WARPS_K; + + // 2x2: 2,2,1 then 2,1,2 + // 1x4: 1,4,1 then 1,1,4 + static_assert(WARPS_K == 1, ""); + + constexpr int WARP_MASK_N = fmha::Warp_masks::N; + constexpr int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + + // Only need to care about warp_n, because if warps_m > 1, both of them should load + // the same data + int const warp = (tidx & WARP_MASK_N) / WARP_DIV_N; + + /* transpose the order of the LDSMs: first along K, then along N + * [ 0 8 1 9 2 10 3 11] + * [12 4 13 5 14 6 15 7] + * [16 24 17 25 18 26 19 27] + * [28 20 29 21 30 22 31 23] + */ + int read_row = (tidx & 0x04) / 4 + (tidx & 0x10) / 8 + warp * Base::ROWS_PER_WARP; + int read_col = (tidx & 0x03) * 2 + (tidx & 0x08) / 8; + read_col ^= (read_row & 0x01); + + this->offset = read_row * Base::BYTES_PER_ROW + read_col * Base::BYTES_PER_LDS; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_b {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_b + : public Smem_tile_qk_interleaved_b_base { + using Traits = fmha::Volta_imma_int8_int32_traits; + using Base = Smem_tile_qk_interleaved_b_base; + + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename 
Base::Fragment; + + inline __device__ Smem_tile_qk_interleaved_b(char* smem, int tidx) : Base(smem, tidx) { + constexpr int WARPS_M = Cta_tile::WARPS_M; + constexpr int WARPS_N = Cta_tile::WARPS_N; + constexpr int WARPS_K = Cta_tile::WARPS_K; + constexpr int WARP_MASK_N = fmha::Warp_masks::N; + constexpr int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + + // Only need to care about warp_n, because if warps_m > 1, both of them should load + // the same data + int const warp = (tidx & WARP_MASK_N) / WARP_DIV_N; + + int read_row = (tidx & 0x04) / 4 + (tidx & 0x08) / 4 + warp * Base::ROWS_PER_WARP; + int read_col = (tidx & 0x03) * 2 + (tidx & 0x10) / 16; + read_col ^= (read_row & 0x01); + + this->offset = read_row * Base::BYTES_PER_ROW + read_col * Base::BYTES_PER_LDS; + } + + inline __device__ void load(Fragment (&frag)[Mma_tile::MMAS_N], int ki) { + int slice = ki / 2; +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ni++) { + uint32_t read_ptr = this->smem_ + this->offset + slice * Base::ROWS * Base::BYTES_PER_ROW / 2; + uint2 data; + ldsm_with_lds(data, + read_ptr + ni * Base::ROWS_PER_WARP * Cta_tile::WARPS_N * Base::BYTES_PER_ROW); + static_assert(Fragment ::NUM_REGS == 2, ""); + frag[ni].reg(0) = data.x; + frag[ni].reg(1) = data.y; + } + this->offset ^= 16; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_b + : public Smem_tile_qk_interleaved_b_base { + using Traits = fmha::Turing_imma_int8_int32_traits; + using Base = Smem_tile_qk_interleaved_b_base; + + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename Base::Fragment; + + inline __device__ Smem_tile_qk_interleaved_b(char* smem, int tidx) : Base(smem, tidx) { + constexpr int WARPS_M = Cta_tile::WARPS_M; + constexpr int WARPS_N = Cta_tile::WARPS_N; + constexpr int WARPS_K = Cta_tile::WARPS_K; + constexpr int WARP_MASK_N = fmha::Warp_masks::N; + constexpr int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + + // Only need to care about warp_n, because if warps_m > 1, both of them should load + // the same data + int const warp = (tidx & WARP_MASK_N) / WARP_DIV_N; + + int read_row = (tidx & 0x04) / 4 + (tidx & 0x08) / 4 + warp * Base::ROWS_PER_WARP; + int read_col = (tidx & 0x03) * 2 + (tidx & 0x10) / 16; + read_col ^= (read_row & 0x01); + + this->offset = read_row * Base::BYTES_PER_ROW + read_col * Base::BYTES_PER_LDS; + } + + inline __device__ void load(Fragment (&frag)[Mma_tile::MMAS_N], int ki) { + int slice = ki / 2; +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ni++) { + uint32_t read_ptr = this->smem_ + this->offset + slice * Base::ROWS * Base::BYTES_PER_ROW / 2; + uint2 data; + fmha::ldsm(data, + read_ptr + ni * Base::ROWS_PER_WARP * Cta_tile::WARPS_N * Base::BYTES_PER_ROW); + static_assert(Fragment ::NUM_REGS == 2, ""); + frag[ni].reg(0) = data.x; + frag[ni].reg(1) = data.y; + } + this->offset ^= 16; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_b + : public Smem_tile_qk_interleaved_b_base { + using Traits = fmha::Ampere_imma_int8_int32_traits; + using Base = Smem_tile_qk_interleaved_b_base; + + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename Base::Fragment; + + inline __device__ Smem_tile_qk_interleaved_b(char* smem, int tidx) : Base(smem, tidx) {} + + inline __device__ void load(Fragment (&frag)[Mma_tile::MMAS_N], int ki) { +#pragma 
unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ni++) { + uint32_t read_ptr = this->smem_ + this->offset + ki * Base::ROWS * Base::BYTES_PER_ROW / 2; + uint4 data; + fmha::ldsm(data, + read_ptr + ni * Base::ROWS_PER_WARP * Cta_tile::WARPS_N * Base::BYTES_PER_ROW); + static_assert(Fragment ::NUM_REGS == 4, ""); + frag[ni].reg(0) = data.x; + frag[ni].reg(1) = data.y; + frag[ni].reg(2) = data.z; + frag[ni].reg(3) = data.w; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v_interleaved_b_base + : public Smem_tile_qkv_interleaved { + using Base = Smem_tile_qkv_interleaved; + + using Mma_tile = typename Base::Mma_tile; + // TODO Row or col? + using Fragment = typename Base::Fragment_b; + + inline __device__ Smem_tile_v_interleaved_b_base(char* smem, int tidx) : Base(smem, tidx) { + // // DEBUG. + // static_assert( Cta_tile::N == 64, "" ); + // // END OF DEBUG. + + constexpr int WARPS_M = Cta_tile::WARPS_M; + constexpr int WARPS_N = Cta_tile::WARPS_N; + constexpr int WARPS_K = Cta_tile::WARPS_K; + + // 2x2: 2,2,1 then 2,1,2 + // 1x4: 1,4,1 then 1,1,4 + static_assert(WARPS_N == 1, ""); + + // Don't need to consider WARP M. For two warps in M, both would read the same tile + constexpr int WARP_MASK_K = fmha::Warp_masks::K; + constexpr int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + // the static assert above ensures, that only warp_m or warp_k is non-zero + int const warp = (tidx & WARP_MASK_K) / WARP_DIV_K; + + /* LDSM.T addresses: warps are split in two to match BMM1-GEMM-N (= BMM2-GEMM-K) register + * layout + * <== GEMM-N = D = 64 ==> + * [ 0: 0 0 1 0 2 0 3 0] WARP 0 + * [ 1: 0 4 0 5 0 6 0 7] + * [ 2: 8 0 9 0 10 0 11 0] + * [ 3: 0 12 0 13 0 14 0 15] + * [ 4: 0 0 0 0 0 0 0 0] WARP 1 + * [ 5: 0 0 0 0 0 0 0 0] + * [ 6: 0 0 0 0 0 0 0 0] + * [ 7: 0 0 0 0 0 0 0 0] + * [ 8: 0 0 0 0 0 0 0 0] WARP 2 + * [ 9: 0 0 0 0 0 0 0 0] + * [10: 0 0 0 0 0 0 0 0] + * [11: 0 0 0 0 0 0 0 0] + * [12: 0 0 0 0 0 0 0 0] WARP 3 + * [13: 0 0 0 0 0 0 0 0] + * [14: 0 0 0 0 0 0 0 0] + * [15: 0 0 0 0 0 0 0 0] + * [16: 16 0 17 0 18 0 19 0] WARP 0 + * [17: 0 20 0 21 0 22 0 23] + * [18: 24 0 25 0 26 0 27 0] + * [19: 0 28 0 29 0 30 0 31] + * etc ... + */ + + // TODO this is a bit misleading, as 4 rows per warp applies to the + // row-major tiles above. In this smem tile, a warp actually owns 8 rows in + // SMEM, but we have 4 rows per slice + + // TODO would be good to rename to SMEM_ROWS_PER_WARP to make this clearer + static_assert(Base::ROWS_PER_WARP == 4, ""); + + read_row = ((tidx & 0x0f) / 4) + warp * Base::ROWS_PER_WARP; + read_col = (tidx & 0x03) * 2; + read_col ^= (read_row & 0x01); + + // this->offset = read_row * Base::BYTES_PER_ROW + read_col * Base::BYTES_PER_LDS; + } + + int read_row; + int read_col; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v_interleaved_b {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v_interleaved_b + : public Smem_tile_v_interleaved_b_base { + using Traits = fmha::Volta_imma_int8_int32_traits; + using Base = Smem_tile_v_interleaved_b_base; + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename Base::Fragment_b; + + // Ctor. + inline __device__ Smem_tile_v_interleaved_b(char* smem, int tidx) : Base(smem, tidx) {} + + // Load fragments from shared memory. 
+ inline __device__ void load(Fragment (&frag)[Mma_tile::MMAS_N], int ki) { + // static_assert(Mma_tile::MMAS_K == 4, ""); + static_assert(Mma_tile::MMAS_N == 4, ""); + static_assert(Base::ROWS_PER_WARP == 4, ""); + // static_assert(Cta_tile::WARPS_K == 2, ""); + + int offset_k = ki * Cta_tile::WARPS_K * Base::ROWS_PER_WARP; +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ni++) { + uint32_t offset = (this->read_row + offset_k) * Base::BYTES_PER_ROW + + (this->read_col ^ (ni & 1)) * Base::BYTES_PER_LDS; + + // for the next 32B in N, we have to jump down K rows, so K / 4 rows in + // smem, which stores 4 canonical 32B rows per 128B + offset += (ni / 2) * Cta_tile::K / 4 * Base::BYTES_PER_ROW; + uint32_t read_ptr = this->smem_ + offset; // + ki * Base::ROWS * Base::BYTES_PER_ROW / 2; + uint2 data = {0, 0}; + ldsmt_with_lds(data, read_ptr); + static_assert(Fragment ::NUM_REGS == 2, ""); + swizzle_rows(frag[ni].reg(0), frag[ni].reg(1), data.x, data.y); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v_interleaved_b + : public Smem_tile_v_interleaved_b_base { + using Traits = fmha::Turing_imma_int8_int32_traits; + using Base = Smem_tile_v_interleaved_b_base; + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename Base::Fragment_b; + + // Ctor. + inline __device__ Smem_tile_v_interleaved_b(char* smem, int tidx) : Base(smem, tidx) {} + + // Load fragments from shared memory. + inline __device__ void load(Fragment (&frag)[Mma_tile::MMAS_N], int ki) { + static_assert(Mma_tile::MMAS_N == 4, ""); + static_assert(Base::ROWS_PER_WARP == 4, ""); + + int offset_k = ki * Cta_tile::WARPS_K * Base::ROWS_PER_WARP; +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ni++) { + uint32_t offset = (this->read_row + offset_k) * Base::BYTES_PER_ROW + + (this->read_col ^ (ni & 1)) * Base::BYTES_PER_LDS; + // for the next 32B in N, we have to jump down K rows, so K / 4 rows in + // smem, which stores 4 canonical 32B rows per 128B + offset += (ni / 2) * Cta_tile::K / 4 * Base::BYTES_PER_ROW; + uint32_t read_ptr = this->smem_ + offset; // + ki * Base::ROWS * Base::BYTES_PER_ROW / 2; + uint2 data = {0, 0}; + fmha::ldsmt(data, read_ptr); + static_assert(Fragment ::NUM_REGS == 2, ""); + swizzle_rows(frag[ni].reg(0), frag[ni].reg(1), data.x, data.y); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v_interleaved_b + : public Smem_tile_v_interleaved_b_base { + // The instruction traits. + using Traits = fmha::Ampere_imma_int8_int32_traits; + // The base class. + using Base = Smem_tile_v_interleaved_b_base; + // The tile of MMAs. + using Mma_tile = typename Base::Mma_tile; + // The fragment loaded. + using Fragment = typename Base::Fragment_b; + + // Ctor. + inline __device__ Smem_tile_v_interleaved_b(char* smem, int tidx) : Base(smem, tidx) {} + + // Load from shared memory. + inline __device__ void load(Fragment (&frag)[Mma_tile::MMAS_N], int ki) { + int offset_k = ki * Cta_tile::WARPS_K * Base::ROWS_PER_WARP * 2; + static_assert(Cta_tile::K != 192 || Mma_tile::MMAS_K == 2, ""); +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ni++) { + uint32_t offset = (this->read_row + offset_k) * Base::BYTES_PER_ROW + + (this->read_col ^ (ni & 1)) * Base::BYTES_PER_LDS; + + // For the next 32B in N, we have to jump down K rows, so K / 4 rows in smem, which + // stores 4 canonical 32B rows per 128B. 
+ offset += (ni / 2) * Cta_tile::K / 4 * Base::BYTES_PER_ROW; + uint32_t read_ptr = this->smem_ + offset; // + ki * Base::ROWS * Base::BYTES_PER_ROW / 2; + uint2 data0 = {0, 0}; + uint2 data1 = {0, 0}; + fmha::ldsmt(data0, read_ptr); + + if (Cta_tile::K != 192 || ki == 0) { + static_assert(Cta_tile::K != 192 || Mma_tile::MMAS_K == 2); + // For 192, with 4 warps, we need 128 rows of K, so for the second ldsm, we need + // only 2x instead of 4x. + int imm = Cta_tile::WARPS_K * Base::ROWS_PER_WARP * Base::BYTES_PER_ROW; + fmha::ldsmt(data1, read_ptr + imm); + } + + static_assert(Fragment ::NUM_REGS == 4, ""); + swizzle_rows(frag[ni].reg(0), frag[ni].reg(2), data0.x, data0.y); + swizzle_rows(frag[ni].reg(1), frag[ni].reg(3), data1.x, data1.y); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/smem_tile_v.h b/csrc/fmha_v2/fmha/smem_tile_v.h new file mode 100644 index 0000000000..67a02f37ca --- /dev/null +++ b/csrc/fmha_v2/fmha/smem_tile_v.h @@ -0,0 +1,1008 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template class Rows_per_xor_pattern, + int BUFFERS_PER_TILE = 1> +struct Smem_tile_v_hmma { + using Base = Smem_tile_without_skews::VALUE, 1>; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_hmma::Base { + // The traits class. + using Traits = fmha::Volta_hmma_fp16_16x16x16_traits; + // The base class. + using Base = typename Smem_tile_v_hmma::Base; + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The fragment. + using Fragment = fmha::Fragment_b; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) { + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // Determine the config. + enum { WARPS_2x1x2 = WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_1x1x8 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 8 }; + + enum { WARPS_1x1x4 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 4 }; + + // Flash Attention uses WARPS_4x1x1 + enum { WARPS_4x1x1 = WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1 }; + + // The row/col read by the thread. + int read_row, read_col; + + // SEQLEN == 128 and N == 16. + if (WARPS_2x1x2 && Cta_tile::N == 16) { + read_row = (tidx & 0x40) / 16 + (tidx & 0x08) / 8; + read_col = (tidx & 0x10) / 16 + (tidx & 0x03) * 2; + + // SEQLEN == 128 and N == 32. 
+ } else if (WARPS_2x1x2 && Cta_tile::N == 32) { + read_row = (tidx & 0x40) / 8 + (tidx & 0x08) / 4 + (tidx & 0x02) / 2; + read_col = (tidx & 0x10) / 16 + (tidx & 0x01) * 4; + + // SEQLEN == 128 and N == 64. + } else if (WARPS_2x1x2 && Cta_tile::N == 64) { + read_row = (tidx & 0x40) / 4 + (tidx & 0x08) / 2 + (tidx & 0x03); + read_col = (tidx & 0x10) / 16; + + // SEQLEN == 256, 512 and N == 16. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 16) { + read_row = (tidx & 0xe0) / 8 + (tidx & 0x08) / 8; + read_col = (tidx & 0x10) / 16 + (tidx & 0x03) * 2; + + // SEQLEN == 256, 512 and N == 32. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 32) { + read_row = (tidx & 0xe0) / 4 + (tidx & 0x08) / 4 + (tidx & 0x02) / 2; + read_col = (tidx & 0x10) / 16 + (tidx & 0x01) * 4; + + // SEQLEN == 256, 384 and 512 and N == 64. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && + (Cta_tile::N == 64 || Cta_tile::N == 128 || Cta_tile::N == 256)) { + read_row = (tidx & 0xe0) / 2 + (tidx & 0x08) / 2 + (tidx & 0x03); + read_col = (tidx & 0x10) / 16; + + // ANY SEQLEN and N == 16. + } else if (WARPS_4x1x1 && Cta_tile::N == 16) { + read_row = (tidx & 0x08) / 8; + read_col = (tidx & 0x10) / 16 + (tidx & 0x03) * 2; + + // ANY SEQLEN and N == 32. + } else if (WARPS_4x1x1 && Cta_tile::N == 32) { + read_row = (tidx & 0x08) / 4 + (tidx & 0x02) / 2; + read_col = (tidx & 0x10) / 16 + (tidx & 0x01) * 4; + + // ANY SEQLEN and N == 64/128/256. + } else if (WARPS_4x1x1 && (Cta_tile::N == 64 || Cta_tile::N == 128 || Cta_tile::N == 256)) { + read_row = (tidx & 0x08) / 2 + (tidx & 0x03); + read_col = (tidx & 0x10) / 16; + + // Not supported! + } else { + assert(false); + } + + // Apply the XOR for the column. + read_col ^= read_row % Base::ROWS_PER_XOR_PATTERN; + + // The shared memory offset. + this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::VALID_MMAS_N], int ki) { +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The column offset. + int offset = this->smem_read_offset_ ^ (ni * 2 * BYTES_PER_LDS); + + // Skip N paddings + if (ni < Mma_tile::VALID_MMAS_N) { + // The rows. + int row_0 = ki * 16 * Cta_tile::WARPS_K + 0; + int row_1 = ki * 16 * Cta_tile::WARPS_K + 8; + + // Load the data using 2x LDS.128. + uint4 tmp; + fmha::lds(tmp, this->smem_ + offset + row_0 * Base::BYTES_PER_ROW_BEFORE_PACKING); + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + + fmha::lds(tmp, this->smem_ + offset + row_1 * Base::BYTES_PER_ROW_BEFORE_PACKING); + b[ni].reg(4) = tmp.x; + b[ni].reg(5) = tmp.y; + b[ni].reg(6) = tmp.z; + b[ni].reg(7) = tmp.w; + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v_turing_hmma + : public Smem_tile_v_hmma::Base { + // The base class. + using Base = typename Smem_tile_v_hmma::Base; + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The fragment. + using Fragment = fmha::Fragment_b; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_v_turing_hmma(void* smem, int tidx) : Base(smem, tidx) { + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // Determine the config. 
+ enum { WARPS_2x1x2 = WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_1x1x8 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 8 }; + + enum { WARPS_1x1x4 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 4 }; + + // Flash Attention uses WARPS_4x1x1 + enum { WARPS_4x1x1 = WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1 }; + + // The row/col read by the thread. + int read_row, read_col; + + // SEQLEN == 128 and N == 16. + if (WARPS_2x1x2 && Cta_tile::N == 16) { + read_row = (tidx & 0x40) / 16 + (tidx & 0x04) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 128 and N == 32. + } else if (WARPS_2x1x2 && Cta_tile::N == 32) { + read_row = (tidx & 0x40) / 8 + (tidx & 0x06) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // SEQLEN == 128 and N == 64. + } else if (WARPS_2x1x2 && Cta_tile::N == 64) { + read_row = (tidx & 0x40) / 4 + (tidx & 0x07); + read_col = (tidx & 0x07); + + // SEQLEN == 256, 512 and N == 16. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 16) { + read_row = (tidx & 0xe0) / 8 + (tidx & 0x04) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 256, 512 and N == 32. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 32) { + read_row = (tidx & 0xe0) / 4 + (tidx & 0x06) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // SEQLEN == 256, 384, 512 and N == 64, 128, 256. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && + (Cta_tile::N == 64 || Cta_tile::N == 128 || Cta_tile::N == 256)) { + read_row = (tidx & 0xe0) / 2 + (tidx & 0x07); + read_col = (tidx & 0x07); + + // ANY SEQLEN and N == 16. + } else if (WARPS_4x1x1 && Cta_tile::N == 16) { + read_row = (tidx & 0x04) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // ANY SEQLEN and N == 32. + } else if (WARPS_4x1x1 && Cta_tile::N == 32) { + read_row = (tidx & 0x06) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // ANY SEQLEN and N == 64/128/256. + } else if ((WARPS_4x1x1) && (Cta_tile::N == 64 || Cta_tile::N == 128 || Cta_tile::N == 256)) { + read_row = (tidx & 0x07); + read_col = (tidx & 0x07); + + // Not supported! + } else { + assert(false); + } + + // The 2nd HMMA. + read_col ^= (tidx & 0x08) / 8; + + // The shared memory offset. + this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::VALID_MMAS_N], int ki) { +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The amount of row packing. + enum { ROW_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING }; + + // Skip N paddings + if (ni < Mma_tile::VALID_MMAS_N) { + // For even values of k value we jump by 16*WARPS_K rows and for odd, we jump by 8 rows. + int row = (ki / 2) * 16 * Cta_tile::WARPS_K / ROW_PACKING + (ki % 2) * 8 / ROW_PACKING; + + // Load the data using LDSM.MT88.2. + uint2 tmp; + fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW); + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + } + + // Move to the next N position. + if (Mma_tile::MMAS_N == 1) { + ; + } else if (Mma_tile::MMAS_N == 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } else if (Mma_tile::MMAS_N == 4) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); + } else if (Mma_tile::MMAS_N == 8) { + this->smem_read_offset_ ^= BYTES_PER_LDS * ((ni & 1) == 0 ? 2 : ((ni & 3) == 3 ? 14 : 6)); + } else if (Mma_tile::MMAS_N == 16) { + this->smem_read_offset_ ^= BYTES_PER_LDS * ((ni & 1) == 0 ? 
2 + : ((ni & 7) == 7) ? 30 + : (((ni & 3) == 3) ? 14 : 6)); + } else { + assert(false); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_turing_hmma { + // The base class. + using Base = Smem_tile_v_turing_hmma; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_turing_hmma { + // The base class. + using Base = Smem_tile_v_turing_hmma; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template class Rows_per_xor_pattern, + int BUFFERS_PER_TILE = 1> +struct Smem_tile_v_imma { + using Base = Smem_tile_without_skews::VALUE, 1>; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_imma::Base { + // The traits class. + using Traits = Volta_imma_int8_int32_traits; + // The base class. + using Base = typename Smem_tile_v_imma::Base; + + // DEBUG. + static_assert(Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING == 2, ""); + // END OF DEBUG. + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The fragment. + using Fragment = fmha::Fragment_b; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) { + // The row/col read by the thread. + int read_row, read_col; + + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // Determine the config. + enum { WARPS_2x1x2 = WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_1x1x8 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 8 }; + + enum { WARPS_1x1x4 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 4 }; + + // SEQLEN == 128 and N == 16. + if (WARPS_2x1x2 && Cta_tile::N == 16) { + read_row = (tidx & 0x40) / 32 + (tidx & 0x08) / 8; + read_col = (tidx & 0x07); + + // SEQLEN == 128 and N == 32. + } else if (WARPS_2x1x2 && Cta_tile::N == 32) { + read_row = (tidx & 0x40) / 16 + (tidx & 0x0c) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 128 and N == 64. + } else if (WARPS_2x1x2 && Cta_tile::N == 64) { + read_row = (tidx & 0x40) / 8 + (tidx & 0x0e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // SEQLEN == 256, 512 and N == 16. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 16) { + read_row = (tidx & 0xe0) / 16 + (tidx & 0x08) / 8; + read_col = (tidx & 0x07); + + // SEQLEN == 256, 512 and N == 32. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 32) { + read_row = (tidx & 0xe0) / 8 + (tidx & 0x0c) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 256, 384, 512 and N == 64. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 64) { + read_row = (tidx & 0xe0) / 4 + (tidx & 0x0e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // Not supported. + } else { + assert(false); + } + + // The shared memory offset. + this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + // Load from shared memory. 
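// --------------------------------------------------------------------------------------------
// Illustrative sketch (editor's aside, not part of the patch): the XOR ladder used by the
// HMMA/IMMA V tiles in this file to advance smem_read_offset_ across MMAS_N. Starting from a
// thread's base offset, the per-iteration XOR visits every 32B-aligned column slot exactly once
// and lands back on the start, so no extra register is needed to track the column.
// Host-side check for MMAS_N = 8 (BYTES_PER_LDS = 16):
#include <cassert>
#include <set>

namespace xor_ladder_sketch {

inline void check_mmas_n_8() {
  constexpr int kBytesPerLds = 16;
  int offset = 0;  // relative to the thread's swizzled base offset
  std::set<int> visited;
  for (int ni = 0; ni < 8; ++ni) {
    visited.insert(offset);
    // Same toggle pattern as the MMAS_N == 8 branch of the Turing HMMA tile above.
    offset ^= kBytesPerLds * ((ni & 1) == 0 ? 2 : ((ni & 3) == 3 ? 14 : 6));
  }
  assert(visited.size() == 8);  // eight distinct 32B-spaced slots: 0, 32, ..., 224
  assert(offset == 0);          // and the ladder wraps back to the starting offset
}

}  // namespace xor_ladder_sketch
// --------------------------------------------------------------------------------------------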
+ inline __device__ void load(Fragment (&b)[Mma_tile::VALID_MMAS_N], int ki) { + static_assert(Mma_tile::MMAS_K == 2 || Mma_tile::MMAS_K == 3 || Mma_tile::MMAS_K == 4 || + Mma_tile::MMAS_K == 6, + ""); +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The amount of row packing. + enum { ROW_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING }; + + // Skip N paddings + if (ni < Mma_tile::VALID_MMAS_N) { + // Jump by 8*16 rows per K but account for packing. + int row = ki * 16 * Cta_tile::WARPS_K / ROW_PACKING; + + // We emulate the Turing logic, which loads the data using LDSM.MT88.2: + // uint2 tmp; + // fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW); + // this call fetches two 8x16 matrices, stacked on top of each other + + // we fake LDSM.MT88.2, with 2 LDS.128 and a shuffle: + // - T 0 - T 7 have the smem addresses of LDSM 0, each should do 16B loads + // - T 8 - T15 have the smem addresses of LSDM 1, each should do 16B loads + int const lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; + + uint4 tmp16{0, 0, 0, 0}; // 16B + + if (lane < 16) { + fmha::lds(tmp16, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW); + } + + uint16_t* tmp16c = reinterpret_cast(&tmp16); // 8x2B: we move pairs + + uint2 tmp; // 2*4B + uint16_t* t = reinterpret_cast(&tmp); // 4x2B + + int const src_col = lane / 4; // 0 - 7 + int const src_row = lane % 4 * 2; + +// We have to shuffle the values to distribute them in the warp. +#pragma unroll + for (int it = 0; it < 8; it++) { + uint16_t val, x, y; + val = tmp16c[it]; + x = __shfl_sync(uint32_t(-1), val, src_row + 0); + __syncwarp(); + y = __shfl_sync(uint32_t(-1), val, src_row + 1); + __syncwarp(); + + if (src_col == it) { + t[0] = x; + t[1] = y; + } + val = tmp16c[it]; + x = __shfl_sync(uint32_t(-1), val, src_row + 8); + __syncwarp(); + y = __shfl_sync(uint32_t(-1), val, src_row + 9); + __syncwarp(); + + if (src_col == it) { + t[2] = x; + t[3] = y; + } + } + + // Repack the elements. With LDSM.T, thread 0 has the following elements in its two + // regs: + // + // R0 = [(n=0 k=0), (n=1 k=0), (n=0 k=8), (n=1 k=8)] + // R1 = [(n=0 k=1), (n=1 k=1), (n=0 k=9), (n=1 k=9)] + // + // We want to repack the values as: + // + // R0 = [(n=0 k=0), (n=0 k=1), (n=0 k=8), (n=0 k=9)] + // R1 = [(n=1 k=0), (n=1 k=1), (n=1 k=8), (n=1 k=9)] + // + // Since that this layout corresponds to the layout of elements in the Fragment_a from + // P. + + swizzle_rows(b[ni].reg(0), b[ni].reg(1), tmp.x, tmp.y); + } + + // Move to the next N position. + if (Mma_tile::MMAS_N == 4) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 1 : 3); + } else { + assert(false); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_imma::Base { + // The traits class. + using Traits = Turing_imma_int8_int32_traits; + // The base class. + using Base = typename Smem_tile_v_imma::Base; + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The fragment. + using Fragment = fmha::Fragment_b; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) { + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // Determine the config. 
+ enum { WARPS_2x1x2 = WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_1x1x8 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 8 }; + + enum { WARPS_1x1x4 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 4 }; + + // The row/col read by the thread. + int read_row, read_col; + + // SEQLEN == 128 and N == 32. + if (WARPS_2x1x2 && Cta_tile::N == 16) { + read_row = (tidx & 0x40) / 32 + (tidx & 0x08) / 8; + read_col = (tidx & 0x07); + + // SEQLEN == 128 and N == 32. + } else if (WARPS_2x1x2 && Cta_tile::N == 32) { + read_row = (tidx & 0x40) / 16 + (tidx & 0x0c) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 128 and N == 64. + } else if (WARPS_2x1x2 && Cta_tile::N == 64) { + read_row = (tidx & 0x40) / 8 + (tidx & 0x0e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // SEQLEN == 256, 512 and N == 16. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 16) { + read_row = (tidx & 0xe0) / 16 + (tidx & 0x08) / 8; + read_col = (tidx & 0x07); + + // SEQLEN == 256, 512 and N == 32. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 32) { + read_row = (tidx & 0xe0) / 8 + (tidx & 0x0c) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 256, 384, 512 and N == 64. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 64) { + read_row = (tidx & 0xe0) / 4 + (tidx & 0x0e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // Not supported. + } else { + assert(false); + } + + // The shared memory offset. + this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::VALID_MMAS_N], int ki) { + static_assert(Mma_tile::MMAS_K == 2 || Mma_tile::MMAS_K == 3 || Mma_tile::MMAS_K == 4 || + Mma_tile::MMAS_K == 6 || Mma_tile::MMAS_K == 8, + ""); +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The amount of row packing. + enum { ROW_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING }; + + // Skip N paddings + if (ni < Mma_tile::VALID_MMAS_N) { + // Jump by 8*16 rows per K but account for packing. + int row = ki * 16 * Cta_tile::WARPS_K / ROW_PACKING; + + // Load the data using LDSM.MT88.2. + uint2 tmp; + fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW); + + // Repack the elements. With LDSM.T, thread 0 has the following elements in its two + // regs: + // + // R0 = [(n=0 k=0), (n=1 k=0), (n=0 k=8), (n=1 k=8)] + // R1 = [(n=0 k=1), (n=1 k=1), (n=0 k=9), (n=1 k=9)] + // + // We want to repack the values as: + // + // R0 = [(n=0 k=0), (n=0 k=1), (n=0 k=8), (n=0 k=9)] + // R1 = [(n=1 k=0), (n=1 k=1), (n=1 k=8), (n=1 k=9)] + // + // Since that this layout corresponds to the layout of elements in the Fragment_a from + // P. + + swizzle_rows(b[ni].reg(0), b[ni].reg(1), tmp.x, tmp.y); + + // b[ni].reg(0) = tmp.x; + // b[ni].reg(1)= tmp.y; + } + + // Move to the next N position. + if (Mma_tile::MMAS_N == 1) { + // Noop. + } else if (Mma_tile::MMAS_N == 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS; + } else if (Mma_tile::MMAS_N == 4) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 1 : 3); + } else { + assert(false); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v_ampere_hmma + : public Smem_tile_v_hmma::Base { + // The base class. + using Base = typename Smem_tile_v_hmma::Base; + // The MMA tile. 
+ using Mma_tile = typename Traits::template Mma_tile; + // The fragment. + using Fragment = fmha::Fragment_b; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_v_ampere_hmma(void* smem, int tidx) : Base(smem, tidx) { + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // Determine the config. + enum { WARPS_2x1x2 = WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_1x1x8 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 8 }; + + enum { WARPS_1x1x4 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 4 }; + + // Flash Attention uses WARPS_4x1x1 + enum { WARPS_4x1x1 = WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1 }; + + // The row/col read by the thread. + int read_row, read_col; + + // SEQLEN == 128 and N == 16. + if (WARPS_2x1x2 && Cta_tile::N == 16) { + read_row = (tidx & 0x40) / 16 + (tidx & 0x0c) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 128 and N == 32. + } else if (WARPS_2x1x2 && Cta_tile::N == 32) { + read_row = (tidx & 0x40) / 8 + (tidx & 0x0e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // SEQLEN == 128 and N == 64/128/256. + } else if (WARPS_2x1x2 && (Cta_tile::N == 64 || Cta_tile::N == 128 || Cta_tile::N == 256)) { + read_row = (tidx & 0x40) / 4 + (tidx & 0x0f); + read_col = (tidx & 0x07); + + // SEQLEN == 256, 512 and N == 16. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 16) { + read_row = (tidx & 0xe0) / 8 + (tidx & 0x0c) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 256, 512 and N == 32. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 32) { + read_row = (tidx & 0xe0) / 4 + (tidx & 0x0e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // SEQLEN == 256, 384, 512 and N == 64/128/256. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && + (Cta_tile::N == 64 || Cta_tile::N == 128 || Cta_tile::N == 256)) { + read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f); + read_col = (tidx & 0x07); + + // ANY SEQLEN and N == 16. + } else if (WARPS_4x1x1 && Cta_tile::N == 16) { + read_row = (tidx & 0x0c) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // ANY SEQLEN and N == 32. + } else if (WARPS_4x1x1 && Cta_tile::N == 32) { + read_row = (tidx & 0x0e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // ANY SEQLEN and N == 64/128/256. + } else if (WARPS_4x1x1 && (Cta_tile::N == 64 || Cta_tile::N == 128 || Cta_tile::N == 256 || + Cta_tile::N == 512)) { + read_row = (tidx & 0x0f); + read_col = (tidx & 0x07); + + // Not supported. + } else { + assert(false); + } + + // The 2nd HMMA. + read_col ^= (tidx & 0x10) / 16; + + // The shared memory offset. + this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::VALID_MMAS_N], int ki) { +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The amount of row packing. + enum { ROW_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING }; + + // Jump by 16 * #warps row. Account for the packing. + int row = ki * 16 * Cta_tile::WARPS_K / ROW_PACKING; + + // Skip N paddings + if (ni < Mma_tile::VALID_MMAS_N) { + // Jump by 16 * #warps row. Account for the packing. + int row = ki * 16 * Cta_tile::WARPS_K / ROW_PACKING; + + // Load the data using LDSM.MT88.2. 
+ uint4 tmp; + fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + + row * Base::BYTES_PER_ROW); + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + } + + // Move the pointer for the next ni. I expect the compiler to not recompute those. + static_assert(Mma_tile::MMAS_N <= 64, ""); + if (Mma_tile::MMAS_N >= 32 && ni % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; + } else if (Mma_tile::MMAS_N >= 16 && ni % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; + } else if (Mma_tile::MMAS_N >= 8 && ni % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; + } else if (Mma_tile::MMAS_N >= 4 && ni % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; + } else if (Mma_tile::MMAS_N >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_ampere_hmma { + // The base class. + using Base = Smem_tile_v_ampere_hmma; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_ampere_hmma { + // The base class. + using Base = Smem_tile_v_ampere_hmma; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_ampere_hmma { + // The base class. + using Base = Smem_tile_v_ampere_hmma; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +template +struct Smem_tile_v_ampere_8bit_mma + : public Smem_tile_v_imma::Base { + // The base class. + using Base = typename Smem_tile_v_imma::Base; + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The fragment. + using Fragment = fmha::Fragment_b; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_v_ampere_8bit_mma(void* smem, int tidx) : Base(smem, tidx) { + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // Determine the config. + enum { WARPS_2x1x2 = WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_1x1x8 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 8 }; + + enum { WARPS_1x1x4 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 4 }; + + enum { WARPS_4x1x1 = WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1 }; + + // The row/col read by the thread. + int read_row, read_col; + + // SEQLEN == 128 and N == 16. + if (WARPS_2x1x2 && Cta_tile::N == 16) { + read_row = (tidx & 0x40) / 32 + (tidx & 0x08) / 8; + read_col = (tidx & 0x07); + + // SEQLEN == 128 and N == 32. + } else if (WARPS_2x1x2 && Cta_tile::N == 32) { + read_row = (tidx & 0x40) / 16 + (tidx & 0x0c) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 128 and N == 64. + } else if (WARPS_2x1x2 && Cta_tile::N == 64) { + read_row = (tidx & 0x40) / 8 + (tidx & 0x0e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // SEQLEN == 256, 512 and N == 16. 
+ } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 16) { + read_row = (tidx & 0xe0) / 16 + (tidx & 0x08) / 8; + read_col = (tidx & 0x07); + + // SEQLEN == 256, 512 and N == 32. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 32) { + read_row = (tidx & 0xe0) / 8 + (tidx & 0x0c) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 256, 384, 512 and N == 64. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 64) { + read_row = (tidx & 0xe0) / 4 + (tidx & 0x0e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + } else if (WARPS_4x1x1 && Cta_tile::N == 32) { + read_row = (tidx % 32) / 4; + read_col = read_row % 2 + (tidx % 4) * 2; + } else if (WARPS_4x1x1 && Cta_tile::N == 64) { + read_row = (tidx % 32) / 2; + read_col = read_row % 4 + (tidx & 0x01) * 4; + } else if (WARPS_4x1x1 && (Cta_tile::N >= 128)) { + read_row = tidx % 32; + read_col = tidx % 8; + + // Not supported. + } else { + assert(false); + } + + // The shared memory offset. + this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::VALID_MMAS_N], int ki) { +// static_assert(Mma_tile::MMAS_K == 3 || Mma_tile::MMAS_K == 2 || Mma_tile::MMAS_K == 1, ""); +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The amount of row packing. + enum { ROW_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING }; + + // // Make sure we do not end up with weird values :) + // static_assert(Cta_tile::WARPS_K % ROW_PACKING == 0, ""); + + // Skip N paddings + if (ni < Mma_tile::VALID_MMAS_N) { + // Jump by 8*32 rows per K but account for the fact that we have packing. + int row_0 = (ki * 32 + 0 * 16) * Cta_tile::WARPS_K / ROW_PACKING; + int row_1 = (ki * 32 + 1 * 16) * Cta_tile::WARPS_K / ROW_PACKING; + + // Load the data using LDSM.MT88.2. + uint32_t smem = this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_; + uint2 tmp_0; + fmha::ldsmt(tmp_0, smem + row_0 * Base::BYTES_PER_ROW); + + // Load the next two values. + uint2 tmp_1 = make_uint2(0u, 0u); + if constexpr (Cta_tile::K > 16) { + fmha::ldsmt(tmp_1, smem + row_1 * Base::BYTES_PER_ROW); + } + + // Repack the elements. With LDSM.T, thread 0 has the following elements in its 4 regs: + // + // R0 = [(n=0 k= 0), (n=1 k= 0), (n=0 k= 1), (n=1 k= 1)] + // R1 = [(n=0 k= 8), (n=1 k= 8), (n=0 k= 9), (n=1 k= 9)] + // R2 = [(n=0 k=128), (n=1 k=128), (n=0 k=129), (n=1 k=129)] + // R3 = [(n=0 k=136), (n=1 k=136), (n=0 k=137), (n=1 k=137)] + // + // We want to repack the values as: + // + // R0 = [(n=0 k= 0), (n=0 k= 1), (n=0 k= 8), (n=0 k= 9)] + // R1 = [(n=0 k=128), (n=0 k=129), (n=0 k=136), (n=0 k=137)] + // R2 = [(n=1 k= 0), (n=1 k= 1), (n=1 k= 8), (n=1 k= 9)] + // R3 = [(n=1 k=128), (n=1 k=129), (n=1 k=136), (n=1 k=137)] + // + // Since this layout corresponds to the layout of elements in the Fragment_a from P. + + swizzle_rows(b[ni].reg(0), b[ni].reg(2), tmp_0.x, tmp_0.y); + swizzle_rows(b[ni].reg(1), b[ni].reg(3), tmp_1.x, tmp_1.y); + } + + // Move to the next N position. 
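// --------------------------------------------------------------------------------------------
// Illustrative sketch (editor's aside, not part of the patch): the byte movement described in
// the repacking comment above. LDSM.T leaves the two n-values interleaved inside each 32-bit
// register; swizzle_rows separates them so each register holds one n with its k-values, the
// layout Fragment_a of P expects. On the GPU this is presumably a byte permute (PRMT /
// __byte_perm with selectors 0x6420 and 0x7531 -- an assumption; the actual swizzle_rows helper
// is defined elsewhere in the patch). The host code below just spells out the gather.
#include <cassert>
#include <cstdint>

namespace swizzle_rows_sketch {

inline uint8_t byte(uint32_t v, int i) { return static_cast<uint8_t>(v >> (8 * i)); }

inline uint32_t pack(uint8_t b0, uint8_t b1, uint8_t b2, uint8_t b3) {
  return uint32_t(b0) | (uint32_t(b1) << 8) | (uint32_t(b2) << 16) | (uint32_t(b3) << 24);
}

// x = [n0k0, n1k0, n0k1, n1k1], y = [n0k8, n1k8, n0k9, n1k9]  (byte 0 = lowest address)
inline void swizzle_rows_host(uint32_t& lo, uint32_t& hi, uint32_t x, uint32_t y) {
  lo = pack(byte(x, 0), byte(x, 2), byte(y, 0), byte(y, 2));  // [n0k0, n0k1, n0k8, n0k9]
  hi = pack(byte(x, 1), byte(x, 3), byte(y, 1), byte(y, 3));  // [n1k0, n1k1, n1k8, n1k9]
}

inline void self_test() {
  // Encode the (n, k) pairs as distinct byte values to make the gather visible.
  uint32_t const x = pack(0x00, 0x10, 0x01, 0x11);  // n0k0, n1k0, n0k1, n1k1
  uint32_t const y = pack(0x08, 0x18, 0x09, 0x19);  // n0k8, n1k8, n0k9, n1k9
  uint32_t lo = 0, hi = 0;
  swizzle_rows_host(lo, hi, x, y);
  assert(lo == pack(0x00, 0x01, 0x08, 0x09));  // all n = 0
  assert(hi == pack(0x10, 0x11, 0x18, 0x19));  // all n = 1
}

}  // namespace swizzle_rows_sketch
// --------------------------------------------------------------------------------------------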
+ if (Mma_tile::MMAS_N >= 32 && ni % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS; + } else if (Mma_tile::MMAS_N >= 16 && ni % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS; + } else if (Mma_tile::MMAS_N >= 8 && ni % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS; + } else if (Mma_tile::MMAS_N >= 4 && ni % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS; + } else if (Mma_tile::MMAS_N >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS; + } else { + assert(false); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_ampere_8bit_mma { + // The base class. + using Base = + Smem_tile_v_ampere_8bit_mma; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_ampere_8bit_mma { + // The base class. + using Base = Smem_tile_v_ampere_8bit_mma; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_ampere_8bit_mma { + // The base class. + using Base = Smem_tile_v_ampere_8bit_mma; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/softmax.h b/csrc/fmha_v2/fmha/softmax.h new file mode 100644 index 0000000000..68ecea49b9 --- /dev/null +++ b/csrc/fmha_v2/fmha/softmax.h @@ -0,0 +1,3964 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include + +#include "fmha/fragment.h" +#include "fmha/utils.h" + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Sum_ { + enum { IS_SUM = 1 }; + + static inline __device__ float apply(float x, float y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Max_ { + enum { IS_SUM = 0 }; + + static inline __device__ float apply(float x, float y) { return fmaxf(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float apply_exp_(float x, float max) { + return isinf(x) ? 
0.f : __expf(x - max); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +inline __device__ float apply_exp_<2>(float x, float max) { + return __expf(x - max); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float get_alibi_head_scaling_factor(int const in_head_id, + AlibiParams const& params) { + int const head_id = params.head_idx_offset + in_head_id; + if (head_id < params.h_pow_2) { + // 2^(head_id * -8 / h) + return exp2f((head_id + 1) * 2 * params.alibi_neg4_div_h) * params.scale_after_alibi; + } else { + // 1,3,5... etc + float const adjusted_head_id = 2 * (head_id - params.h_pow_2) + 1; + // 2^(adjusted_head_id * -4 / h) + return exp2f(adjusted_head_id * params.alibi_neg4_div_h) * params.scale_after_alibi; + ; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ReadType { + using T = float; +}; + +template <> +struct ReadType<4> { + using T = float; +}; + +template <> +struct ReadType<8> { + using T = float2; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_reduce { + // Helper class to distribute MMA tiles reduced over rows per warp over quads. + + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + static constexpr int ROWS = WARPS_M * MMAS_M * 16; + static constexpr int COLS = WARPS_N; + static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8; + static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float); + static constexpr int ELTS_PER_TILE = ROWS * COLS; + + static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW; + static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP; + static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS; + + using read_t = typename ReadType::T; + + __device__ inline Smem_tile_reduce(float* smem_, int const tidx) { + int lane = tidx % 32; + int warp = tidx / 32; + + int warp_m = warp % WARPS_M; + int warp_n = warp / WARPS_M; + + qid_ = lane % 4; + int qp = lane / 4; + + // Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps. + // This won't affect reading as we assume commutative reduction ops. 
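// Host-side sketch of the ALiBi slope schedule that get_alibi_head_scaling_factor
// above implements. It assumes AlibiParams encodes h_pow_2 as the largest power
// of two <= num_heads and alibi_neg4_div_h as -4.0f / h_pow_2; scale_after_alibi
// and head_idx_offset are left out of the sketch.
#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<float> alibi_slopes_sketch(int num_heads) {
  int h_pow_2 = 1;
  while (h_pow_2 * 2 <= num_heads) h_pow_2 *= 2;
  float const neg4_div_h = -4.0f / static_cast<float>(h_pow_2);
  std::vector<float> slopes(num_heads);
  for (int h = 0; h < num_heads; ++h) {
    if (h < h_pow_2) {
      slopes[h] = std::exp2f((h + 1) * 2 * neg4_div_h);  // 2^(-8*(h+1)/h_pow_2)
    } else {
      float const adjusted = 2.0f * (h - h_pow_2) + 1.0f;
      slopes[h] = std::exp2f(adjusted * neg4_div_h);     // 2^(-4*(2*(h-h_pow_2)+1)/h_pow_2)
    }
  }
  return slopes;
}

int main() {
  for (float s : alibi_slopes_sketch(12)) printf("%g ", s);
  printf("\n");
  return 0;
}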
+ int const col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN); + smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col]; + smem_read_ = &reinterpret_cast(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_]; + } + + __device__ inline void store(float (&frag)[2 * MMAS_M]) { + if (qid_ == 0) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; mi++) { + int offset = mi * 16 * WARPS_N; + smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0]; + smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1]; + } + } + } + + __device__ inline void load(read_t (&frag)[2 * MMAS_M]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; mi++) { + int offset = mi * 16 * 4; + frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4]; + frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4]; + } + } + + int qid_; + float* smem_write_; + read_t* smem_read_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_base { + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // The number of groups of warp such that we have at most 4 warps writing consecutive elements. + enum { GROUPS = fmha::Div_up::VALUE }; + + // The number of elements that we are going to store per row. + enum { ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS }; + + // The number of rows. + enum { ROWS = Cta_tile::M * GROUPS }; + + // The total number of elements. + enum { ELEMENTS = ROWS * ELEMENTS_PER_ROW }; + + // If shared memory is used + enum { USE_SHARED_MEMORY = Cta_tile::WARPS_N > 1 }; + + // DEBUG. + static_assert(ELEMENTS == Cta_tile::M * Cta_tile::WARPS_N, ""); + + // END OF DEBUG. + + // The number of rows per thread. + enum { ROWS_PER_THREAD = MMAS_M * 2 }; + + // Ctor. + template + inline __device__ Softmax_base(Params const& params, void* smem, int bidb, int tidx) + : smem_(reinterpret_cast(smem)), tidx_(tidx) { + // Extract the position in the warp. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // Decompose the warp index into M and N. + int warp_m = warp % Cta_tile::WARPS_M; + int warp_n = warp / Cta_tile::WARPS_M; + + // Decompose the warp-n index into group/position-inside-the-group. + int warp_g = warp_n / ELEMENTS_PER_ROW; + int warp_i = warp_n % ELEMENTS_PER_ROW; + + // The location written by the threads. + int write_row = warp_g * Cta_tile::M + warp_m * Mma_tile::M_PER_MMA + lane / 4; + int write_col = warp_i; + + // Assemble the write pointer. + smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col]; + + // Assemble the read pointer. + smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4]; + } + + // Apply mask before softmax. Use 1 byte per MMA distributed as 2x4. 
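// Sketch of the accumulator-element -> (row, col) mapping that the masking code
// below indexes as elt_[2 * mi + ii][4 * ni + jj]. This is the standard
// Turing/Ampere fp16 m16n8 accumulator layout for one 16x16 tile; the warp- and
// tile-level offsets (warp_m, warp_n, mi, ni) are added on top of it elsewhere.
static inline __device__ void hmma_elt_to_row_col_sketch(int lane, int ii, int jj, int& row,
                                                         int& col) {
  // ii in [0, 2): which 8-row half of the 16-row tile this element sits in.
  // jj in [0, 4): which of the four columns owned by this thread.
  row = (lane / 4) + 8 * ii;
  col = 2 * (lane % 4) + (jj & 1) + 8 * (jj / 2);  // columns {0, 1, 8, 9} per quad position
}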
+ template + inline __device__ void apply_mask(Mask const& mask) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int jj = 0; jj < 4; ++jj) { + if (!mask.is_valid(mi, ni, ii, jj)) { + elt_[2 * mi + ii][4 * ni + jj] = -FLT_MAX; + } + } + } + } + } + } + + template + inline __device__ void apply_mask_alibi(Mask const& mask, int head_id, + AlibiParams const& alibi_params) { + // 'if constexpr' because ALiBi is only defined for causal masks + if constexpr (Kernel_traits::CAUSAL_MASK) { + float m = get_alibi_head_scaling_factor(head_id, alibi_params); +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int jj = 0; jj < 4; ++jj) { + int row, col; + mask.get_row_col(row, col, mi, ni, ii, jj); + if (mask.is_valid(row, col)) { + // Since softmax is shift invariant, + // it is sufficient just to use the column as the multiplier + elt_[2 * mi + ii][4 * ni + jj] = + elt_[2 * mi + ii][4 * ni + jj] * alibi_params.scale_after_alibi + + m * (col + alibi_params.sequence_pos_offset); + } else { + elt_[2 * mi + ii][4 * ni + jj] = -FLT_MAX; + } + } + } + } + } + } else { + __builtin_unreachable(); + } + } + + // Apply the mask to unpacked data. + inline __device__ void apply_mask(uint32_t const (&packed_mask)[MMAS_M]) { + // This code works only if we have MMAS_N <= 4. + static_assert(MMAS_N <= 4, ""); + + // Expand the mask. + int mask[MMAS_M * 2][MMAS_N * 4]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + mask[2 * mi + 0][4 * ni + 0] = packed_mask[mi] & (1u << (8 * ni + 0)); + mask[2 * mi + 0][4 * ni + 1] = packed_mask[mi] & (1u << (8 * ni + 1)); + mask[2 * mi + 1][4 * ni + 0] = packed_mask[mi] & (1u << (8 * ni + 2)); + mask[2 * mi + 1][4 * ni + 1] = packed_mask[mi] & (1u << (8 * ni + 3)); + mask[2 * mi + 0][4 * ni + 2] = packed_mask[mi] & (1u << (8 * ni + 4)); + mask[2 * mi + 0][4 * ni + 3] = packed_mask[mi] & (1u << (8 * ni + 5)); + mask[2 * mi + 1][4 * ni + 2] = packed_mask[mi] & (1u << (8 * ni + 6)); + mask[2 * mi + 1][4 * ni + 3] = packed_mask[mi] & (1u << (8 * ni + 7)); + } + } + +// Apply the mask. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + if (!mask[mi][ni]) { + elt_[mi][ni] = -FLT_MAX; + } + } + } + } + + // Mask the elements that are outside the the sequence length. + inline __device__ void apply_mask(int const actual_seqlen) { + // The warp/lane decomposition. + int const warp = threadIdx.x / Cta_tile::THREADS_PER_WARP; + int const lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; + + // The warp in the n dimension. + int const warp_n = warp / Cta_tile::WARPS_M; + // The position within a quad. + int const quad_lane = lane % 4; + +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // Determine the position in the sequence. 
+ int const offset = ni * Mma_tile::N_PER_MMA_PER_CTA + warp_n * 16; + if (offset + 0 + 2 * quad_lane >= actual_seqlen) { + elt_[mi][4 * ni + 0] = -FLT_MAX; // 0 + } + if (offset + 1 + 2 * quad_lane >= actual_seqlen) { + elt_[mi][4 * ni + 1] = -FLT_MAX; // 1 + } + if (offset + 8 + 2 * quad_lane >= actual_seqlen) { + elt_[mi][4 * ni + 2] = -FLT_MAX; // 8 + } + if (offset + 9 + 2 * quad_lane >= actual_seqlen) { + elt_[mi][4 * ni + 3] = -FLT_MAX; // 9 + } + } + } + } + + // Apply the exp to all the elements. + inline __device__ void apply_exp(float const max) { +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + elt_[mi][ni] = apply_exp_(elt_[mi][ni], max); + } + } + } + + // Apply the exp to all the elements. + inline __device__ void apply_scale_exp(float const (&max)[MMAS_M * 2], float scale_bmm1) { +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + elt_[mi][ni] = apply_exp_(scale_bmm1 * elt_[mi][ni], max[mi]); + } + } + } + + // Apply the exp to all the elements. + inline __device__ void apply_exp(float const (&max)[MMAS_M * 2]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]); + } + } + } + + // Do a warp-wide reduction. + template + inline __device__ void reduce_Nx1(float (&dst)[MMAS_M * 2]) { +#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + if (Functor::IS_SUM) { +// Apply the summation inside the thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + float tmp[2] = {0.f, 0.f}; +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + tmp[0] += elt_[mi][4 * ni + 0] + elt_[mi][4 * ni + 1]; + tmp[1] += elt_[mi][4 * ni + 2] + elt_[mi][4 * ni + 3]; + } + dst[mi] = tmp[0] + tmp[1]; + } + } else +#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + { +// Apply the functor for each row inside a thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + dst[mi] = elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < MMAS_N * 4; ++ni) { + dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]); + } + } + } + +// Apply the functor for each row inside each group of 4 threads. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1)); + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2)); + } + } + + // Do a CTA-wide reduction. + template + inline __device__ float reduce_2x2() { + float dst[MMAS_M * 2]; +#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + if (Functor::IS_SUM) { +// Apply the summation inside the thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + // Pair-wise adds in the different threads of the reference code (x+y and z+w). + float a_01 = elt_[mi][0] + elt_[mi][1]; + float a_45 = elt_[mi][4] + elt_[mi][5]; + + //// tmp[0/1] += __shfl_xor(2) in the reference code. + a_01 += elt_[mi][2] + elt_[mi][3]; + a_45 += elt_[mi][6] + elt_[mi][7]; + + //// tmp[0/1] += __shfl_xor(8) in the reference code. 
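// Minimal sketch of the quad-level butterfly used by the reductions above and
// below: once a thread has folded its own elements into `v`, two XOR shuffles
// (lane masks 1 and 2) combine the four threads of a quad so that each of them
// ends up holding the quad-wide result. Assumes all 32 lanes are active.
template <typename Functor>
static inline __device__ float quad_allreduce_sketch(float v) {
  v = Functor::apply(v, __shfl_xor_sync(0xffffffffu, v, 1));
  v = Functor::apply(v, __shfl_xor_sync(0xffffffffu, v, 2));
  return v;
}
// Usage with the functors defined in this file: quad_allreduce_sketch<fmha::Max_>(x)
// for the running maximum, quad_allreduce_sketch<fmha::Sum_>(x) for the row sum.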
+ a_01 += a_45; + + if (MMAS_N >= 3) { + float a_89 = elt_[mi][8] + elt_[mi][9]; + a_89 += elt_[mi][10] + elt_[mi][11]; + if (MMAS_N == 4) { + float a_cd = elt_[mi][12] + elt_[mi][13]; + a_cd += elt_[mi][14] + elt_[mi][15]; + a_89 += a_cd; + } + a_01 += a_89; + } + dst[mi] = a_01; + } + } else +#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + { +// Apply the functor for each row inside a thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + dst[mi] = elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < MMAS_N * 4; ++ni) { + dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]); + } + } + } + +// Apply the functor for each row inside each group of 4 threads. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1)); + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2)); + } + +// Store the different values. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + if (tidx_ % 4 == 0) { + smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0]; + smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1]; + } + } + + // Make sure the values are in shared memory. + __syncthreads(); + + // Load 2 values (one for each warp). + float2 tmp = reinterpret_cast(smem_)[tidx_]; + + // Compute the reduction of those 2 values in a binary-tree fashion. + return Functor::apply(tmp.x, tmp.y); + } + + // Do a CTA-wide reduction. + template + inline __device__ float reduce_1x4() { + float dst[MMAS_M * 2]; +#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + if (Functor::IS_SUM) { +// Apply the summation inside the thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + float tmp[2] = {0.f, 0.f}; +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + tmp[0] += elt_[mi][4 * ni + 0] + elt_[mi][4 * ni + 1]; + tmp[1] += elt_[mi][4 * ni + 2] + elt_[mi][4 * ni + 3]; + } + dst[mi] = tmp[0] + tmp[1]; + } + } else +#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + { +// Apply the functor for each row inside a thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + dst[mi] = elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < MMAS_N * 4; ++ni) { + dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]); + } + } + } + +// Apply the functor for each row inside each group of 4 threads. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1)); + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2)); + } + +// Store the different values. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + if (tidx_ % 4 == 0) { + smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0]; + smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1]; + } + } + + // Make sure the values are in shared memory. + __syncthreads(); + + // Load 8 values (one for each warp). The /8 corresponds to /(4*2) where 4 is from the + // float4. + float4 tmp[1]; + if (tidx_ < Cta_tile::M) { + tmp[0] = reinterpret_cast(&smem_[0 * ELEMENTS / 2])[tidx_]; + } + + // Compute the reduction of those 8 values in a binary-tree fashion. + tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y); + tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w); + tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z); + + // Return the final reduction. + return tmp[0].x; + } + + // Do a CTA-wide reduction. 
+ template + inline __device__ float reduce_1x8() { + float dst[MMAS_M * 2]; +#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + if (Functor::IS_SUM) { + // Apply the summation inside the thread. + float tmp[MMAS_M * 2][2]; +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + tmp[mi][0] = 0.f; + tmp[mi][1] = 0.f; +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + tmp[mi][0] += elt_[mi][4 * ni + 0]; + tmp[mi][0] += elt_[mi][4 * ni + 1]; + tmp[mi][1] += elt_[mi][4 * ni + 2]; + tmp[mi][1] += elt_[mi][4 * ni + 3]; + } + dst[mi] = tmp[mi][0] + tmp[mi][1]; + } + } else +#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + { +// Apply the functor for each row inside a thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + dst[mi] = elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < MMAS_N * 4; ++ni) { + dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]); + } + } + } + +// Apply the functor for each row inside each group of 4 threads. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1)); + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2)); + } + +// Store the different values. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + if (tidx_ % 4 == 0) { + smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0]; + smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1]; + } + } + + // Make sure the values are in shared memory. + __syncthreads(); + + // Load 8 values (one for each warp). The /8 corresponds to /(4*2) where 4 is from the + // float4. + float4 tmp[2]; + if (tidx_ < Cta_tile::M) { + tmp[0] = reinterpret_cast(&smem_[0 * ELEMENTS / 2])[tidx_]; + tmp[1] = reinterpret_cast(&smem_[1 * ELEMENTS / 2])[tidx_]; + } + + // Compute the reduction of those 8 values in a binary-tree fashion. + tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y); + tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w); + tmp[1].x = Functor::apply(tmp[1].x, tmp[1].y); + tmp[1].z = Functor::apply(tmp[1].z, tmp[1].w); + tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z); + tmp[1].x = Functor::apply(tmp[1].x, tmp[1].z); + tmp[0].x = Functor::apply(tmp[0].x, tmp[1].x); + + // Return the result. + return tmp[0].x; + } + + // Do a CTA-wide reduction. + template + inline __device__ float reduce_() { + // The result of the reduction. Threads 0..Cta_tile::M-1 own a single row value. + float red = 0.f; + + // SEQLEN == 128. + if (Cta_tile::WARPS_M == 2 && Cta_tile::WARPS_N == 2) { + red = reduce_2x2(); + + // SEQLEN == 256. + } else if (Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 4) { + red = reduce_1x4(); + + // SEQLEN == 384. + } else if (Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 8) { + red = reduce_1x8(); + + // Not supported. + } else { + assert(false); + } + + return red; + } + + // Finalize the reduction. + inline __device__ void shuffle(float (&dst)[MMAS_M * 2], float red) { + // Store the value back to shared memory. + if (tidx_ < Cta_tile::M) { + smem_[tidx_] = red; + } + + // Make sure the data is in shared memory. + __syncthreads(); + +// Finally read the values. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0]; + dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8]; + } + + // Make sure the data is in shared memory. + __syncthreads(); + } + + // Do a CTA-wide reduction. 
+ template + inline __device__ void reduce(float (&dst)[MMAS_M * 2]) { + // NOTE: 1 warp along reduce direction, no syncs + if (Cta_tile::WARPS_N == 1) { + reduce_Nx1(dst); + } else { + // The result of the reduction. Threads 0..Cta_tile::M-1 own a single row value. + float red = reduce_(); + + // Make sure we can write to shared memory. + __syncthreads(); + + // Finalize the reduction. + shuffle(dst, red); + } + } + + // Scale all the elements. + inline __device__ void scale(float const (&sum)[MMAS_M * 2]) { + // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal. + float inv_sum[MMAS_M * 2]; +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi]; + } + +// Update the values. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + elt_[mi][ni] *= inv_sum[mi]; + } + } + } + + // Shared memory for the CTA-wide reduction. + float *smem_, *smem_write_, *smem_read_; + // The current thread index. + int tidx_; + // The elements. + float elt_[MMAS_M * 2][MMAS_N * 4]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_hmma : public Softmax_base { + // The base class. + using Base = Softmax_base; + + // The MMAs. + enum { MMAS_M = Base::MMAS_M }; + + enum { MMAS_N = Base::MMAS_N }; + + // Whether we need to skip the softmax due to the sliding-window attention + // Otherwise, we will get NANs as those tokens are all masked out. + enum { SLIDING_WINDOW_ATTENTION = Kernel_traits::SLIDING_WINDOW_ATTENTION }; + + // Use BMM1 softcapping scale or not. + enum { ENABLE_BMM1_SOFTCAPPING_SCALE = Kernel_traits::ENABLE_BMM1_SOFTCAPPING_SCALE }; + + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + + // Softmax dst data_type (BMM2 input) + using Dst_type = typename Traits::A_type; + + // Ctor. + template + inline __device__ Softmax_hmma(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx), + params_scale_bmm1_(params.scale_bmm1), + params_softcapping_scale_bmm1_(params.softcapping_scale_bmm1) {} + + // Store the tile after softmax. + template + inline __device__ void store(Gmem_tile& gmem_tile) { + Accumulator acc[MMAS_M][MMAS_N]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // The elements. + float tmp_00 = this->elt_[2 * mi + 0][4 * ni + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ni + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ni + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ni + 3]; + float tmp_10 = this->elt_[2 * mi + 1][4 * ni + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ni + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ni + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ni + 3]; + + // Transform to accumulators. + acc[mi][ni].reg(0) = fmha::float2_to_16bit_2(tmp_00, tmp_01); + acc[mi][ni].reg(1) = fmha::float2_to_16bit_2(tmp_10, tmp_11); + acc[mi][ni].reg(2) = fmha::float2_to_16bit_2(tmp_02, tmp_03); + acc[mi][ni].reg(3) = fmha::float2_to_16bit_2(tmp_12, tmp_13); + } + } + + // Delegate to the gmem tile to store. + gmem_tile.store(acc); + } + + // Convert from FP16 fragments to floats. 
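// Host-side sketch of the guard used by scale() above: rows whose exponential
// sum is zero (every token masked out) or NaN are normalized with 1.0f instead
// of 1/sum, so a fully masked row stays all-zero rather than turning into
// inf/NaN probabilities.
#include <cstdio>

static void normalize_row_sketch(float* row, int n, float sum) {
  float const inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;  // sum != sum catches NaN
  for (int i = 0; i < n; ++i) row[i] *= inv_sum;
}

int main() {
  float row[4] = {0.f, 0.f, 0.f, 0.f};  // a fully masked row: expf(-FLT_MAX - 0.f) == 0.f
  normalize_row_sketch(row, 4, 0.f);
  printf("%f %f %f %f\n", row[0], row[1], row[2], row[3]);  // stays 0, not NaN
  return 0;
}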
+ inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // Normalize the values, and clamp to finite half. + uint32_t acc_0 = satfinite_h2(hmul2(acc[mi][ni].reg(0), params_scale_bmm1_)); + uint32_t acc_1 = satfinite_h2(hmul2(acc[mi][ni].reg(1), params_scale_bmm1_)); + uint32_t acc_2 = satfinite_h2(hmul2(acc[mi][ni].reg(2), params_scale_bmm1_)); + uint32_t acc_3 = satfinite_h2(hmul2(acc[mi][ni].reg(3), params_scale_bmm1_)); + + // Extract the values as floats. + half2_to_float2(this->elt_[2 * mi + 0][4 * ni + 0], this->elt_[2 * mi + 0][4 * ni + 1], + acc_0); + half2_to_float2(this->elt_[2 * mi + 1][4 * ni + 0], this->elt_[2 * mi + 1][4 * ni + 1], + acc_1); + half2_to_float2(this->elt_[2 * mi + 0][4 * ni + 2], this->elt_[2 * mi + 0][4 * ni + 3], + acc_2); + half2_to_float2(this->elt_[2 * mi + 1][4 * ni + 2], this->elt_[2 * mi + 1][4 * ni + 3], + acc_3); + + // Attention logit softcapping scale. + // 1.0f / softcapping_scale has been fused to scale_bmm1. + if constexpr (ENABLE_BMM1_SOFTCAPPING_SCALE) { + this->elt_[2 * mi + 0][4 * ni + 0] = + params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 0][4 * ni + 0]); + this->elt_[2 * mi + 0][4 * ni + 1] = + params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 0][4 * ni + 1]); + this->elt_[2 * mi + 1][4 * ni + 0] = + params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 1][4 * ni + 0]); + this->elt_[2 * mi + 1][4 * ni + 1] = + params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 1][4 * ni + 1]); + this->elt_[2 * mi + 0][4 * ni + 2] = + params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 0][4 * ni + 2]); + this->elt_[2 * mi + 0][4 * ni + 3] = + params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 0][4 * ni + 3]); + this->elt_[2 * mi + 1][4 * ni + 2] = + params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 1][4 * ni + 2]); + this->elt_[2 * mi + 1][4 * ni + 3] = + params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 1][4 * ni + 3]); + } + } + } + } + + // Apply the exp to all the elements. + // Need to make sure the results are zero when all elts are -FLT_MAX + // as it is possible that all tokens are masked out. + template + inline __device__ void apply_exp_with_mask(float const (&max)[MMAS_M * 2]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + float max_val = APPLY_MASK && max[mi] == -FLT_MAX ? 0.f : max[mi]; +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + this->elt_[mi][ni] = expf(this->elt_[mi][ni] - max_val); + } + } + } + + // The scaling factor. + uint32_t const params_scale_bmm1_; + float const params_softcapping_scale_bmm1_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_helper {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_helper { + // The traits. + using Traits = fmha::Volta_imma_int8_int32_traits; + // The fragment A. + using Fragment_a = fmha::Fragment_a; + // The accumulator. + using Accumulator = fmha::Fragment_accumulator; + + // Load a 2x4 array from registers. 
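// Host-side sketch of the logit softcapping applied in unpack() above: since
// 1/softcapping_scale is already fused into scale_bmm1, the unpacked value is
// raw * scale / cap, and the tanh step bounds the final logit to (-cap, cap).
// The numbers below are example values, not taken from any kernel config.
#include <cmath>
#include <cstdio>

int main() {
  float const cap = 30.f;      // softcapping_scale_bmm1 (example)
  float const scale = 0.125f;  // 1/sqrt(d); 1/cap is additionally fused into scale_bmm1
  float const raw = 4000.f;    // a raw BMM1 accumulator value (example)
  float const unpacked = raw * (scale / cap);     // what unpack() sees after the fused scale
  float const logit = cap * std::tanh(unpacked);  // bounded to (-cap, cap)
  printf("softcapped logit = %f\n", logit);
  return 0;
}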
+ static inline __device__ void load(int32_t (&dst)[2][4], Accumulator const& src) { + dst[0][0] = src.elt(0); + dst[0][1] = src.elt(1); + dst[0][2] = src.elt(2); + dst[0][3] = src.elt(3); + dst[1][0] = src.elt(4); + dst[1][1] = src.elt(5); + dst[1][2] = src.elt(6); + dst[1][3] = src.elt(7); + } + + // Store to an accumulator. + static inline __device__ void store(Accumulator& dst, uint32_t const (&src)[2][4]) { + dst.reg(0) = src[0][0]; + dst.reg(1) = src[0][1]; + dst.reg(2) = src[0][2]; + dst.reg(3) = src[0][3]; + dst.reg(4) = src[1][0]; + dst.reg(5) = src[1][1]; + dst.reg(6) = src[1][2]; + dst.reg(7) = src[1][3]; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_helper { + // The traits. + using Traits = fmha::Turing_imma_int8_int32_traits; + // The fragment A. + using Fragment_a = fmha::Fragment_a; + // The accumulator. + using Accumulator = fmha::Fragment_accumulator; + + // Load a 2x4 array from registers. + static inline __device__ void load(int32_t (&dst)[2][4], Accumulator const& src) { + dst[0][0] = src.elt(0); + dst[0][1] = src.elt(1); + dst[0][2] = src.elt(2); + dst[0][3] = src.elt(3); + dst[1][0] = src.elt(4); + dst[1][1] = src.elt(5); + dst[1][2] = src.elt(6); + dst[1][3] = src.elt(7); + } + + // Store to an accumulator. + static inline __device__ void store(Accumulator& dst, uint32_t const (&src)[2][4]) { + dst.reg(0) = src[0][0]; + dst.reg(1) = src[0][1]; + dst.reg(2) = src[0][2]; + dst.reg(3) = src[0][3]; + dst.reg(4) = src[1][0]; + dst.reg(5) = src[1][1]; + dst.reg(6) = src[1][2]; + dst.reg(7) = src[1][3]; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_helper { + // The traits. + using Traits = fmha::Ampere_imma_int8_int32_traits; + // The fragment A. + using Fragment_a = fmha::Fragment_a; + // The accumulator. + using Accumulator = fmha::Fragment_accumulator; + + // Load a 2x4 array from registers. + static inline __device__ void load(int32_t (&dst)[2][4], Accumulator const& src) { + dst[0][0] = src.elt(0); + dst[0][1] = src.elt(1); + dst[0][2] = src.elt(4); + dst[0][3] = src.elt(5); + dst[1][0] = src.elt(2); + dst[1][1] = src.elt(3); + dst[1][2] = src.elt(6); + dst[1][3] = src.elt(7); + } + + // Store to an accumulator. + static inline __device__ void store(Accumulator& dst, uint32_t const (&src)[2][4]) { + dst.reg(0) = src[0][0]; + dst.reg(1) = src[0][1]; + dst.reg(4) = src[0][2]; + dst.reg(5) = src[0][3]; + dst.reg(2) = src[1][0]; + dst.reg(3) = src[1][1]; + dst.reg(6) = src[1][2]; + dst.reg(7) = src[1][3]; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_imma : public Softmax_base { + // The base class. + using Base = Softmax_base; + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The MMAs. + enum { MMAS_M = Base::MMAS_M }; + + enum { MMAS_N = Base::MMAS_N }; + + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + // The fragment. + using Fragment_a = fmha::Fragment_a; + + // The dst type + using Dst_type = typename Traits::A_type; + + // Ctor. + template + inline __device__ Softmax_imma(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx), + params_scale_bmm1_(params.scale_bmm1), + params_scale_softmax_(params.scale_softmax) {} + + // Store the tile after softmax. 
+ template + inline __device__ void store(Gmem_tile& gmem_tile) { + float const scale = reinterpret_cast(params_scale_softmax_); + Accumulator acc[MMAS_M][MMAS_N]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // Scale the FP32 elements. + uint32_t tmp[2][4]; +#pragma unroll + for (int mj = 0; mj < 2; ++mj) { +#pragma unroll + for (int nj = 0; nj < 4; ++nj) { + float f = this->elt_[2 * mi + mj][4 * ni + nj] * scale; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(tmp[mj][nj]) : "f"(f)); + } + } + + // Convert to int8 and store. + Fragment_helper::store(acc[mi][ni], tmp); + } + } + + // Delegate to the gmem tile to store. + gmem_tile.store(acc); + } + + // Convert from accumulators to floats. + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { + float const scale = reinterpret_cast(params_scale_bmm1_); +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // Load the values from the accumulator's registers. + int32_t tmp[2][4]; + Fragment_helper::load(tmp, acc[mi][ni]); + +// Convert to FP32 and scale. +#pragma unroll + for (int mj = 0; mj < 2; ++mj) { +#pragma unroll + for (int nj = 0; nj < 4; ++nj) { +#if defined(USE_I2F_EMULATION_TRICK) + float f = reinterpret_cast(tmp[mj][nj]); + this->elt_[2 * mi + mj][4 * ni + nj] = (f - FP32_I2F_MAGIC_NUMBER) * scale; +#else + this->elt_[2 * mi + mj][4 * ni + nj] = static_cast(tmp[mj][nj]) * scale; +#endif // defined(USE_I2F_EMULATION_TRICK) + } + } + } + } + } + + // Repack. We could use store/load to match the Smem_tile API. (shared by Ampere IMMA and Ada + // QMMA) + template + inline __device__ void pack(Fragment_a_ (&dst)[K][M]) { + // We pack N 16x16 acc tiles into K 16x32 tiles for A. + // In the 16x16 tile, a thread owns 4 elts per row (4 regs). + // In the 16x32 A tile, a thread owns 8 elts per row (2 regs). + // Hence we have to pack with a 2:1 ratio. + // For N = 1, K is 1: pack 4 values into dst reg 0. Set reg 1 to 0. + // For N = 2, K is 1: pack 8 values into dst regs 0, 1. + // For N = 3, K is 2: pack 12 values into dst regs (0,0), (0,1), (1,0). Set (1,1) to 0. + // For N = 4, K is 2: pack 16 values into dst regs (0,0), (0,1), (1,0), (1,1) + // For N = 5, K is 3: pack 20 values into dst regs (0,0), (0,1), (1,0), (1,1), (2,0). Set (2,1) + // to 0. For N = 6, K is 3: pack 24 values into dst regs (0,0), (0,1), (1,0), (1,1), (2,0), + // (2,1) + + static_assert(K == 3 || K == 2 || K == 1, ""); + + float const scale = reinterpret_cast(this->params_scale_softmax_); + +#pragma unroll + for (int mi = 0; mi < M; ++mi) { + // 1st row - 12 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][0] * scale; + float tmp_01 = this->elt_[2 * mi + 0][1] * scale; + float tmp_02 = this->elt_[2 * mi + 0][2] * scale; + float tmp_03 = this->elt_[2 * mi + 0][3] * scale; + float tmp_04 = this->elt_[2 * mi + 0][4] * scale; + float tmp_05 = this->elt_[2 * mi + 0][5] * scale; + float tmp_06 = this->elt_[2 * mi + 0][6] * scale; + float tmp_07 = this->elt_[2 * mi + 0][7] * scale; + float tmp_08 = this->elt_[2 * mi + 0][8] * scale; + float tmp_09 = this->elt_[2 * mi + 0][9] * scale; + float tmp_0a = this->elt_[2 * mi + 0][10] * scale; + float tmp_0b = this->elt_[2 * mi + 0][11] * scale; + + // 2nd row - 12 elements per row. 
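// Host-side sketch of what the inline PTX "cvt.rni.sat.s8.f32" in store() above
// does: round to nearest (ties to even) and saturate to the int8 range before
// the bytes are packed into the accumulator registers.
#include <cmath>
#include <cstdint>
#include <cstdio>

static int8_t float_to_s8_rni_sat_sketch(float f) {
  float r = nearbyintf(f);  // rounds to nearest, ties to even, under the default FP environment
  if (r > 127.f) r = 127.f;
  if (r < -128.f) r = -128.f;
  return static_cast<int8_t>(r);
}

int main() {
  printf("%d %d %d\n", float_to_s8_rni_sat_sketch(126.6f), float_to_s8_rni_sat_sketch(1000.f),
         float_to_s8_rni_sat_sketch(-0.5f));  // 127, 127, 0
  return 0;
}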
+ float tmp_20 = this->elt_[2 * mi + 1][0] * scale; + float tmp_21 = this->elt_[2 * mi + 1][1] * scale; + float tmp_22 = this->elt_[2 * mi + 1][2] * scale; + float tmp_23 = this->elt_[2 * mi + 1][3] * scale; + float tmp_24 = this->elt_[2 * mi + 1][4] * scale; + float tmp_25 = this->elt_[2 * mi + 1][5] * scale; + float tmp_26 = this->elt_[2 * mi + 1][6] * scale; + float tmp_27 = this->elt_[2 * mi + 1][7] * scale; + float tmp_28 = this->elt_[2 * mi + 1][8] * scale; + float tmp_29 = this->elt_[2 * mi + 1][9] * scale; + float tmp_2a = this->elt_[2 * mi + 1][10] * scale; + float tmp_2b = this->elt_[2 * mi + 1][11] * scale; + + // Pack the first 12 elements to 6 registers of 2 fragments. + dst[0][mi].reg(0) = fmha::float4_to_8bitx4(tmp_00, tmp_01, tmp_02, tmp_03); + dst[0][mi].reg(1) = fmha::float4_to_8bitx4(tmp_20, tmp_21, tmp_22, tmp_23); + dst[0][mi].reg(2) = fmha::float4_to_8bitx4(tmp_04, tmp_05, tmp_06, tmp_07); + dst[0][mi].reg(3) = fmha::float4_to_8bitx4(tmp_24, tmp_25, tmp_26, tmp_27); + if (K > 1) { + dst[1][mi].reg(0) = fmha::float4_to_8bitx4(tmp_08, tmp_09, tmp_0a, tmp_0b); + dst[1][mi].reg(1) = fmha::float4_to_8bitx4(tmp_28, tmp_29, tmp_2a, tmp_2b); + } + + if (Mma_tile::MMAS_N == 6) { + float tmp_0c = this->elt_[2 * mi + 0][12] * scale; + float tmp_0d = this->elt_[2 * mi + 0][13] * scale; + float tmp_0e = this->elt_[2 * mi + 0][14] * scale; + float tmp_0f = this->elt_[2 * mi + 0][15] * scale; + float tmp_10 = this->elt_[2 * mi + 0][16] * scale; + float tmp_11 = this->elt_[2 * mi + 0][17] * scale; + float tmp_12 = this->elt_[2 * mi + 0][18] * scale; + float tmp_13 = this->elt_[2 * mi + 0][19] * scale; + float tmp_14 = this->elt_[2 * mi + 0][20] * scale; + float tmp_15 = this->elt_[2 * mi + 0][21] * scale; + float tmp_16 = this->elt_[2 * mi + 0][22] * scale; + float tmp_17 = this->elt_[2 * mi + 0][23] * scale; + + float tmp_2c = this->elt_[2 * mi + 1][12] * scale; + float tmp_2d = this->elt_[2 * mi + 1][13] * scale; + float tmp_2e = this->elt_[2 * mi + 1][14] * scale; + float tmp_2f = this->elt_[2 * mi + 1][15] * scale; + float tmp_30 = this->elt_[2 * mi + 1][16] * scale; + float tmp_31 = this->elt_[2 * mi + 1][17] * scale; + float tmp_32 = this->elt_[2 * mi + 1][18] * scale; + float tmp_33 = this->elt_[2 * mi + 1][19] * scale; + float tmp_34 = this->elt_[2 * mi + 1][20] * scale; + float tmp_35 = this->elt_[2 * mi + 1][21] * scale; + float tmp_36 = this->elt_[2 * mi + 1][22] * scale; + float tmp_37 = this->elt_[2 * mi + 1][23] * scale; + + dst[1][mi].reg(2) = fmha::float4_to_8bitx4(tmp_0c, tmp_0d, tmp_0e, tmp_0f); + dst[1][mi].reg(3) = fmha::float4_to_8bitx4(tmp_2c, tmp_2d, tmp_2e, tmp_2f); + dst[2][mi].reg(0) = fmha::float4_to_8bitx4(tmp_10, tmp_11, tmp_12, tmp_13); + dst[2][mi].reg(1) = fmha::float4_to_8bitx4(tmp_30, tmp_31, tmp_32, tmp_33); + dst[2][mi].reg(2) = fmha::float4_to_8bitx4(tmp_14, tmp_15, tmp_16, tmp_17); + dst[2][mi].reg(3) = fmha::float4_to_8bitx4(tmp_34, tmp_35, tmp_36, tmp_37); + } else if (Mma_tile::MMAS_N == 4) { + // SEQLEN == 128. 
+ float tmp_0c = this->elt_[2 * mi + 0][12] * scale; + float tmp_0d = this->elt_[2 * mi + 0][13] * scale; + float tmp_0e = this->elt_[2 * mi + 0][14] * scale; + float tmp_0f = this->elt_[2 * mi + 0][15] * scale; + + float tmp_1c = this->elt_[2 * mi + 1][12] * scale; + float tmp_1d = this->elt_[2 * mi + 1][13] * scale; + float tmp_1e = this->elt_[2 * mi + 1][14] * scale; + float tmp_1f = this->elt_[2 * mi + 1][15] * scale; + + dst[1][mi].reg(2) = fmha::float4_to_8bitx4(tmp_0c, tmp_0d, tmp_0e, tmp_0f); + dst[1][mi].reg(3) = fmha::float4_to_8bitx4(tmp_1c, tmp_1d, tmp_1e, tmp_1f); + + // SEQLEN == 384 or SEQLEN == 256. + } else if (Mma_tile::MMAS_N == 3 || Mma_tile::MMAS_N == 2) { + // TODO added second OR term for ampere imma s=256: correct? + dst[1][mi].reg(2) = 0u; + dst[1][mi].reg(3) = 0u; + } else if (Mma_tile::MMAS_N == 1) { + dst[0][mi].reg(2) = 0u; + dst[0][mi].reg(3) = 0u; + + // Not implemented. + } else { + assert(false); + } + } + } + + // The scaling factors. + uint32_t const params_scale_bmm1_, params_scale_softmax_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_qmma : public Softmax_imma {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_qmma + : public Softmax_imma { + // The Traits + using Traits = fmha::Ada_qmma_e4m3_fp32_traits; + // The base class. + using Base = Softmax_imma; + + // The MMAs. + enum { MMAS_M = Base::MMAS_M }; + + enum { MMAS_N = Base::MMAS_N }; + + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + + // Ctor. + template + inline __device__ Softmax_qmma(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx), + params_scale_bmm1_(params.scale_bmm1_d ? *params.scale_bmm1_d : params.scale_bmm1), + params_scale_softmax_(params.scale_softmax) {} + + // Store the tile after softmax. + template + inline __device__ void store(Gmem_tile& gmem_tile) { + float const scale = reinterpret_cast(params_scale_softmax_); + Accumulator acc[MMAS_M][MMAS_N]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // scale + acc[mi][ni].ele(0) = this->elt_[2 * mi + 0][4 * ni + 0] * scale; + acc[mi][ni].ele(1) = this->elt_[2 * mi + 0][4 * ni + 1] * scale; + acc[mi][ni].ele(4) = this->elt_[2 * mi + 0][4 * ni + 2] * scale; + acc[mi][ni].ele(5) = this->elt_[2 * mi + 0][4 * ni + 3] * scale; + acc[mi][ni].ele(2) = this->elt_[2 * mi + 1][4 * ni + 0] * scale; + acc[mi][ni].ele(3) = this->elt_[2 * mi + 1][4 * ni + 1] * scale; + acc[mi][ni].ele(6) = this->elt_[2 * mi + 1][4 * ni + 2] * scale; + acc[mi][ni].ele(7) = this->elt_[2 * mi + 1][4 * ni + 3] * scale; + } + } + + // Delegate to the gmem tile to store. + // TODO: need fp32 to fp8 conversion (move this to gmem_tile) + gmem_tile.store(acc); + } + + // Convert from accumulators to floats. + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { + float const scale = reinterpret_cast(params_scale_bmm1_); +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // Convert to FP32 and scale. 
+ this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scale; + this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scale; + this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scale; + this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scale; + this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scale; + this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scale; + this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scale; + this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scale; + } + } + } + + template + inline __device__ void apply_exp_with_mask(float const (&max)[MMAS_M * 2]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + float max_val = APPLY_MASK && max[mi] == -FLT_MAX + ? 0.f + : (max[mi] - logf(Traits::SOFTMAX_FP_QUANT_SCALE)); +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + this->elt_[mi][ni] = expf(this->elt_[mi][ni] - max_val); + } + } + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { + float const scale = reinterpret_cast(this->params_scale_softmax_); + +// The canonical layout in K should be R0: [0,1,2,3] R2: [16,17,18,19] +// Note below that this is not possible with the register layout of the accumulator. +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 8 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][8 * ki + 0] * scale; // + 0 + float tmp_01 = this->elt_[2 * mi + 0][8 * ki + 1] * scale; // + 1 + float tmp_02 = this->elt_[2 * mi + 0][8 * ki + 2] * scale; // + 8 + float tmp_03 = this->elt_[2 * mi + 0][8 * ki + 3] * scale; // + 9 + float tmp_04 = this->elt_[2 * mi + 0][8 * ki + 4] * scale; // +16 + float tmp_05 = this->elt_[2 * mi + 0][8 * ki + 5] * scale; // +17 + float tmp_06 = this->elt_[2 * mi + 0][8 * ki + 6] * scale; // +24 + float tmp_07 = this->elt_[2 * mi + 0][8 * ki + 7] * scale; // +25 + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][8 * ki + 0] * scale; // + 0 + float tmp_11 = this->elt_[2 * mi + 1][8 * ki + 1] * scale; // + 1 + float tmp_12 = this->elt_[2 * mi + 1][8 * ki + 2] * scale; // + 8 + float tmp_13 = this->elt_[2 * mi + 1][8 * ki + 3] * scale; // + 9 + float tmp_14 = this->elt_[2 * mi + 1][8 * ki + 4] * scale; // +16 + float tmp_15 = this->elt_[2 * mi + 1][8 * ki + 5] * scale; // +17 + float tmp_16 = this->elt_[2 * mi + 1][8 * ki + 6] * scale; // +24 + float tmp_17 = this->elt_[2 * mi + 1][8 * ki + 7] * scale; // +25 + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float4_to_fp8x4(tmp_00, tmp_01, tmp_02, tmp_03); + dst[ki][mi].reg(1) = fmha::float4_to_fp8x4(tmp_10, tmp_11, tmp_12, tmp_13); + dst[ki][mi].reg(2) = fmha::float4_to_fp8x4(tmp_04, tmp_05, tmp_06, tmp_07); + dst[ki][mi].reg(3) = fmha::float4_to_fp8x4(tmp_14, tmp_15, tmp_16, tmp_17); + } + } + } + + // The scaling factors. + uint32_t const params_scale_bmm1_, params_scale_softmax_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_qmma + : public Softmax_imma { + // The Traits + using Traits = fmha::Ada_qmma_e4m3_fp16_traits; + // The base class. + using Base = Softmax_imma; + + // The MMAs. + enum { MMAS_M = Base::MMAS_M }; + + enum { MMAS_N = Base::MMAS_N }; + + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + + // Ctor. 
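// Sketch of the quantization trick used by apply_exp_with_mask() above for the
// e4m3 path: subtracting ln(s) from the row max multiplies every softmax
// numerator by s, because exp(x - (m - ln s)) == s * exp(x - m). The actual
// value of s (Traits::SOFTMAX_FP_QUANT_SCALE) is not reproduced here.
#include <cmath>
#include <cstdio>

int main() {
  float const x = 1.7f, m = 3.2f, s = 256.f;  // example values
  float const shifted = std::exp(x - (m - std::log(s)));
  float const scaled = s * std::exp(x - m);
  printf("%f vs %f\n", shifted, scaled);  // equal up to rounding
  return 0;
}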
+ template + inline __device__ Softmax_qmma(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx), + params_scale_bmm1_(params.scale_bmm1), + params_scale_softmax_(params.scale_softmax) {} + + // Store the tile after softmax. + template + inline __device__ void store(Gmem_tile& gmem_tile) { + float const scale = reinterpret_cast(params_scale_softmax_); + Accumulator acc[MMAS_M][MMAS_N]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // scale + acc[mi][ni].ele(0) = this->elt_[2 * mi + 0][4 * ni + 0] * scale; + acc[mi][ni].ele(1) = this->elt_[2 * mi + 0][4 * ni + 1] * scale; + acc[mi][ni].ele(4) = this->elt_[2 * mi + 0][4 * ni + 2] * scale; + acc[mi][ni].ele(5) = this->elt_[2 * mi + 0][4 * ni + 3] * scale; + acc[mi][ni].ele(2) = this->elt_[2 * mi + 1][4 * ni + 0] * scale; + acc[mi][ni].ele(3) = this->elt_[2 * mi + 1][4 * ni + 1] * scale; + acc[mi][ni].ele(6) = this->elt_[2 * mi + 1][4 * ni + 2] * scale; + acc[mi][ni].ele(7) = this->elt_[2 * mi + 1][4 * ni + 3] * scale; + } + } + + // Delegate to the gmem tile to store. + // TODO: need fp32 to fp8 conversion (move this to gmem_tile) + gmem_tile.store(acc); + } + + // Convert from accumulators to floats. + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // Convert to FP32 and scale. + float2* elt_ptr0 = reinterpret_cast(this->elt_[2 * mi + 0] + 4 * ni); + float2* elt_ptr1 = reinterpret_cast(this->elt_[2 * mi + 1] + 4 * ni); + elt_ptr0[0] = fmha::half2_to_float2(fmha::hmul2(acc[mi][ni].reg(0), params_scale_bmm1_)); + elt_ptr0[1] = fmha::half2_to_float2(fmha::hmul2(acc[mi][ni].reg(2), params_scale_bmm1_)); + elt_ptr1[0] = fmha::half2_to_float2(fmha::hmul2(acc[mi][ni].reg(1), params_scale_bmm1_)); + elt_ptr1[1] = fmha::half2_to_float2(fmha::hmul2(acc[mi][ni].reg(3), params_scale_bmm1_)); + } + } + } + + // The scaling factors. + uint32_t const params_scale_bmm1_, params_scale_softmax_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + // The traits class. + using Traits = fmha::Volta_hmma_fp16_traits; + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + // The fragment. + using Fragment_a = fmha::Fragment_a; + + // Softmax dst data_type (BMM2 input) + using Dst_type = typename Traits::A_type; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // The number of groups of warp such that we have at most 2 warps writing consecutive elements. + enum { GROUPS = fmha::Div_up::VALUE }; + + // The number of elements that we are going to store per row. + enum { ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS }; + + // The number of rows. + enum { ROWS = Cta_tile::M * GROUPS }; + + // The total number of elements. + enum { ELEMENTS = ROWS * ELEMENTS_PER_ROW }; + + // Use BMM1 softcapping scale or not. + enum { ENABLE_BMM1_SOFTCAPPING_SCALE = Kernel_traits::ENABLE_BMM1_SOFTCAPPING_SCALE }; + + // If shared memory is used + enum { USE_SHARED_MEMORY = Cta_tile::WARPS_N > 1 }; + + // The number of rows per thread. 
+ enum { ROWS_PER_THREAD = MMAS_M }; + + // DEBUG. + static_assert(ELEMENTS == Cta_tile::M * Cta_tile::WARPS_N, ""); + + // END OF DEBUG. + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : params_scale_bmm1_(params.scale_bmm1), + params_softcapping_scale_bmm1_(params.softcapping_scale_bmm1), + smem_(reinterpret_cast(smem)), + tidx_(tidx) { + // Extract the position in the warp. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // Decompose the warp index into M and N. + int warp_m = warp % Cta_tile::WARPS_M; + int warp_n = warp / Cta_tile::WARPS_M; + + // Decompose the warp-n index into group/position-inside-the-group. + int warp_g = warp_n / ELEMENTS_PER_ROW; + int warp_i = warp_n % ELEMENTS_PER_ROW; + + // The row written/read by the thread (threads i and i+8 are on the same row). + int row = (lane & 0x10) / 2 + (lane & 0x07); + + // The location written by the threads. + int write_row = warp_g * Cta_tile::M + warp_m * Mma_tile::M_PER_MMA + row; + int write_col = warp_i; + + // Assemble the write pointer. + smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col]; + // Assemble the read pointer. + smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + row]; + } + + // Apply mask before softmax. Use 1 byte per MMA distributed as 1x8. + template + inline __device__ void apply_mask(Mask const& mask) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < 8; ++ii) { + if (!mask.is_valid(mi, ni, 0, ii)) { + elt_[mi][8 * ni + ii] = -FLT_MAX; + } + } + } + } + } + + template + inline __device__ void apply_mask_alibi(Mask const& mask, int head_id, + AlibiParams const& alibi_params) { + // 'if constexpr' because ALiBi is only defined for causal masks + if constexpr (Kernel_traits::CAUSAL_MASK) { + float m = get_alibi_head_scaling_factor(head_id, alibi_params); +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < 8; ++ii) { + int row, col; + mask.get_row_col(row, col, mi, ni, 0, ii); + if (mask.is_valid(row, col)) { + // Since softmax is shift invariant, + // it is sufficient just to use the column as the multiplier + elt_[mi][8 * ni + ii] = elt_[mi][8 * ni + ii] * alibi_params.scale_after_alibi + + m * (col + alibi_params.sequence_pos_offset); + } else { + elt_[mi][8 * ni + ii] = -FLT_MAX; + } + } + } + } + } else { + __builtin_unreachable(); + } + } + + // Apply the mask to unpacked data. + inline __device__ void apply_mask(uint32_t const (&packed_mask)[MMAS_M]) { + // This code works only if we have MMAS_N <= 4. + static_assert(MMAS_N <= 4, ""); + + // Expand the mask. + int mask[MMAS_M][MMAS_N * 8]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ii = 0; ii < MMAS_N * 8; ++ii) { + mask[mi][ii] = packed_mask[mi] & (1u << ii); + } + } + +// Apply the mask. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 8; ++ni) { + if (!mask[mi][ni]) { + elt_[mi][ni] = -FLT_MAX; + } + } + } + } + + // Mask the elements that are outside the the sequence length. + inline __device__ void apply_mask(int const seqlen) { + // The warp/lane decomposition. + int const warp = threadIdx.x / Cta_tile::THREADS_PER_WARP; + int const lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; + + // The warp in the n dimension. 
+ int const warp_n = warp / Cta_tile::WARPS_M; + // The base position within a quad. + int const offset = warp_n * 16 + (threadIdx.x & 0x08) / 2; + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // The position in the sequence. + int pos = offset + ni * Mma_tile::N_PER_MMA_PER_CTA; + + // Determine the position in the sequence. + if (pos + 0 >= seqlen) { + elt_[mi][8 * ni + 0] = -FLT_MAX; + } + if (pos + 1 >= seqlen) { + elt_[mi][8 * ni + 1] = -FLT_MAX; + } + if (pos + 2 >= seqlen) { + elt_[mi][8 * ni + 2] = -FLT_MAX; + } + if (pos + 3 >= seqlen) { + elt_[mi][8 * ni + 3] = -FLT_MAX; + } + if (pos + 8 >= seqlen) { + elt_[mi][8 * ni + 4] = -FLT_MAX; + } + if (pos + 9 >= seqlen) { + elt_[mi][8 * ni + 5] = -FLT_MAX; + } + if (pos + 10 >= seqlen) { + elt_[mi][8 * ni + 6] = -FLT_MAX; + } + if (pos + 11 >= seqlen) { + elt_[mi][8 * ni + 7] = -FLT_MAX; + } + } + } + } + + // Apply the exp to all the elements. + // Need to make sure the results are zero when all elts are -FLT_MAX + // as it is possible that all tokens are masked out. + template + inline __device__ void apply_exp_with_mask(float const (&max)[MMAS_M]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + float max_val = APPLY_MASK && max[mi] == -FLT_MAX ? 0.f : max[mi]; +#pragma unroll + for (int ni = 0; ni < MMAS_N * 8; ++ni) { + this->elt_[mi][ni] = expf(this->elt_[mi][ni] - max_val); + } + } + } + + // Apply the exp to all the elements. + inline __device__ void apply_exp(float const max) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 8; ++ni) { + elt_[mi][ni] = apply_exp_(elt_[mi][ni], max); + } + } + } + + // Apply the exp to all the elements. + inline __device__ void apply_exp(float const (&max)[MMAS_M]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 8; ++ni) { + elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]); + } + } + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { + static_assert(MMAS_M == M && MMAS_N == K, ""); +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 8 elements per row. + float tmp_0 = this->elt_[mi][8 * ki + 0]; + float tmp_1 = this->elt_[mi][8 * ki + 1]; + float tmp_2 = this->elt_[mi][8 * ki + 2]; + float tmp_3 = this->elt_[mi][8 * ki + 3]; + float tmp_4 = this->elt_[mi][8 * ki + 4]; + float tmp_5 = this->elt_[mi][8 * ki + 5]; + float tmp_6 = this->elt_[mi][8 * ki + 6]; + float tmp_7 = this->elt_[mi][8 * ki + 7]; + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float2_to_16bit_2(tmp_0, tmp_1); + dst[ki][mi].reg(1) = fmha::float2_to_16bit_2(tmp_2, tmp_3); + dst[ki][mi].reg(2) = fmha::float2_to_16bit_2(tmp_4, tmp_5); + dst[ki][mi].reg(3) = fmha::float2_to_16bit_2(tmp_6, tmp_7); + } + } + } + + // Do a CTA-wide reduction. + template + inline __device__ void reduce_Nx1(float (&dst)[MMAS_M]) { +#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + if (Functor::IS_SUM) { +// Apply the summation inside the thread for each row. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // The thread local math in the reference code. 
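// Sketch of the per-thread column mapping that the Volta HMMA apply_mask(seqlen)
// above relies on: within one 16-wide MMA column block, the 8 fp16 elements a
// thread owns sit at offsets {0,1,2,3} and {8,9,10,11} relative to a base that
// depends on bit 3 of the thread index; the warp_n and ni offsets are added
// separately, as in the code above.
static inline __device__ int volta_elt_col_sketch(int tidx_in_warp, int elt /* 0..7 */) {
  int const base = (tidx_in_warp & 0x08) / 2;  // 0 or 4
  return base + (elt < 4 ? elt : elt + 4);     // {0,1,2,3} then {8,9,10,11}
}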
+ float sums[MMAS_N * 2]; +#pragma unroll + for (int ii = 0; ii < MMAS_N * 2; ++ii) { + sums[ii] = elt_[mi][4 * ii + 0]; + sums[ii] += elt_[mi][4 * ii + 1]; + sums[ii] += elt_[mi][4 * ii + 2]; + sums[ii] += elt_[mi][4 * ii + 3]; + } + +// Columns 0 and 8: __shfl( 2). +#pragma unroll + for (int ii = 0; ii < MMAS_N; ++ii) { + sums[2 * ii] += sums[2 * ii + 1]; + } + +// Columns 0 and 32: __shfl( 8). +#pragma unroll + for (int ii = 0; ii < MMAS_N / 2; ++ii) { // MMAS_N / 2 == 0 if MMAS_N <= 1. + sums[4 * ii] += sums[4 * ii + 2]; + } + + // Columns 0 and 64: __shfl(16). + if (MMAS_N == 3) { + sums[0] += sums[4]; + } else if (MMAS_N >= 4) { +#pragma unroll + for (int ii = 0; ii < MMAS_N / 4; ++ii) { // MMAS_N / 4 == 0 if MMAS_N <= 2. + sums[8 * ii] += sums[8 * ii + 4]; + } + } + + // Store the final value for that row. + dst[mi] = sums[0]; + } + } else +#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + { +// Apply the functor for each row inside a thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[mi] = elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < MMAS_N * 8; ++ni) { + dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]); + } + } + } + +// Apply the functor for each row. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 8)); + } + } + + // Do a CTA-wide reduction. + template + inline __device__ float reduce_2x2() { + float dst[MMAS_M]; +#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + if (Functor::IS_SUM) { +// Apply the summation inside the thread for each row. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // The thread local math in the reference code. + float sums[MMAS_N * 2]; +#pragma unroll + for (int ii = 0; ii < MMAS_N * 2; ++ii) { + sums[ii] = elt_[mi][4 * ii + 0]; + sums[ii] += elt_[mi][4 * ii + 1]; + sums[ii] += elt_[mi][4 * ii + 2]; + sums[ii] += elt_[mi][4 * ii + 3]; + } + +// Columns 0 and 8: __shfl( 2). +#pragma unroll + for (int ii = 0; ii < MMAS_N; ++ii) { + sums[2 * ii] += sums[2 * ii + 1]; + } + +// Columns 0 and 32: __shfl( 8). +#pragma unroll + for (int ii = 0; ii < MMAS_N / 2; ++ii) { // MMAS_N / 2 == 0 if MMAS_N <= 1. + sums[4 * ii] += sums[4 * ii + 2]; + } + + // Columns 0 and 64: __shfl(16). + if (MMAS_N == 3) { + sums[0] += sums[4]; + } else if (MMAS_N >= 4) { +#pragma unroll + for (int ii = 0; ii < MMAS_N / 4; ++ii) { // MMAS_N / 4 == 0 if MMAS_N <= 2. + sums[8 * ii] += sums[8 * ii + 4]; + } + } + + // Store the final value for that row. + dst[mi] = sums[0]; + } + } else +#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + { +// Apply the functor for each row inside a thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[mi] = elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < MMAS_N * 8; ++ni) { + dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]); + } + } + } + +// Apply the functor for each row. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 8)); + } + +// Store the different values to shared memory. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + if (tidx_ % 16 < 8) { + smem_write_[mi * Mma_tile::M_PER_MMA_PER_CTA * ELEMENTS_PER_ROW] = dst[mi]; + } + } + + // Make sure the values are in shared memory. + __syncthreads(); + + // Load 2 values (one for each warp). + float2 tmp = reinterpret_cast(smem_)[tidx_]; + + // Compute the reduction of those 2 values in a binary-tree fashion. 
+ return Functor::apply(tmp.x, tmp.y); + } + + // Do a CTA-wide reduction. + template + inline __device__ float reduce_1x4() { + float dst[MMAS_M]; +#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + if (Functor::IS_SUM) { +// Apply the summation inside the thread for each row. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // The thread local math in the reference code. + float sums[MMAS_N * 2]; +#pragma unroll + for (int ii = 0; ii < MMAS_N * 2; ++ii) { + sums[ii] = elt_[mi][4 * ii + 0]; + sums[ii] += elt_[mi][4 * ii + 1]; + sums[ii] += elt_[mi][4 * ii + 2]; + sums[ii] += elt_[mi][4 * ii + 3]; + } + + // Columns 0 and 128 (the ref code uses a step of 128). Not needed if SEQLEN <= 128. + if (Cta_tile::N > 128) { +#pragma unroll + for (int ii = 0; ii < MMAS_N; ++ii) { + sums[ii] += sums[MMAS_N + ii]; + } + } + +// Columns 0 and 8: __shfl( 2). +#pragma unroll + for (int ii = 0; ii < MMAS_N; ++ii) { + sums[2 * ii] += sums[2 * ii + 1]; + } + +// Columns 0 and 64: __shfl(16). +#pragma unroll + for (int ii = 0; ii < MMAS_N / 2; ++ii) { // MMAS_N / 2 == 0 if MMAS_N <= 1. + sums[4 * ii] += sums[4 * ii + 2]; + } + + // Store the final value for that row. + dst[mi] = sums[0]; + } + } else +#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + { +// Apply the functor for each row inside a thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[mi] = elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < MMAS_N * 8; ++ni) { + dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]); + } + } + } + +// Apply the functor for each row. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 8)); + } + +// Store the different values to shared memory. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + if (tidx_ % 16 < 8) { + smem_write_[mi * Mma_tile::M_PER_MMA_PER_CTA * ELEMENTS_PER_ROW] = dst[mi]; + } + } + + // Make sure the values are in shared memory. + __syncthreads(); + + // Load 4 values (one for each warp). + float2 tmp[2]; + if (tidx_ < Cta_tile::M) { + tmp[0] = reinterpret_cast(&smem_[0 * ELEMENTS / 2])[tidx_]; + tmp[1] = reinterpret_cast(&smem_[1 * ELEMENTS / 2])[tidx_]; + } + + // Compute the reduction of those 4 values in a binary-tree fashion. + tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y); + tmp[1].x = Functor::apply(tmp[1].x, tmp[1].y); + tmp[0].x = Functor::apply(tmp[0].x, tmp[1].x); + + // Return the final reduction. + return tmp[0].x; + } + + // Do a CTA-wide reduction. + template + inline __device__ float reduce_1x8() { + float dst[MMAS_M]; +#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + if (Functor::IS_SUM) { +// Apply the summation inside the thread for each row. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // The thread local math in the reference code. + float sums[MMAS_N * 2]; +#pragma unroll + for (int ii = 0; ii < MMAS_N * 2; ++ii) { + sums[ii] = elt_[mi][4 * ii + 0]; + sums[ii] += elt_[mi][4 * ii + 1]; + sums[ii] += elt_[mi][4 * ii + 2]; + sums[ii] += elt_[mi][4 * ii + 3]; + } + +// Columns 0 and 128 (the ref code uses a step of 128). Not needed if SEQLEN <= 128. +#pragma unroll + for (int ii = 1; ii < MMAS_N; ++ii) { + sums[0] += sums[2 * ii + 0]; + sums[1] += sums[2 * ii + 1]; + } + + // Columns 0 and 8: __shfl( 2). + dst[mi] = sums[0] + sums[1]; + } + } else +#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + { +// Apply the functor for each row inside a thread. 
+#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[mi] = elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < MMAS_N * 8; ++ni) { + dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]); + } + } + } + +// Apply the functor for each row. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 8)); + } + +// Store the different values to shared memory. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + if (tidx_ % 16 < 8) { + smem_write_[mi * Mma_tile::M_PER_MMA_PER_CTA * ELEMENTS_PER_ROW] = dst[mi]; + } + } + + // Make sure the values are in shared memory. + __syncthreads(); + + // Load 8 values (one for each warp). + float2 tmp[4]; + if (tidx_ < Cta_tile::M) { + tmp[0] = reinterpret_cast(&smem_[0 * ELEMENTS / 4])[tidx_]; + tmp[1] = reinterpret_cast(&smem_[1 * ELEMENTS / 4])[tidx_]; + tmp[2] = reinterpret_cast(&smem_[2 * ELEMENTS / 4])[tidx_]; + tmp[3] = reinterpret_cast(&smem_[3 * ELEMENTS / 4])[tidx_]; + } + + // // DEBUG. + // if( tidx_ == 0 ) { + // #pragma unroll + // for( int ii = 0; ii < 4; ++ii ) { + // printf("tidx=%3d tmp[%d]=%8.3f %8.3f\n", tidx_, ii, tmp[ii].x, tmp[ii].y); + // } + // } + // // END OF DEBUG. + + // Compute the reduction of those 8 values in a binary-tree fashion. + tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y); + tmp[1].x = Functor::apply(tmp[1].x, tmp[1].y); + tmp[2].x = Functor::apply(tmp[2].x, tmp[2].y); + tmp[3].x = Functor::apply(tmp[3].x, tmp[3].y); + + tmp[0].x = Functor::apply(tmp[0].x, tmp[1].x); + tmp[2].x = Functor::apply(tmp[2].x, tmp[3].x); + + tmp[0].x = Functor::apply(tmp[0].x, tmp[2].x); + + // Return the final reduction. + return tmp[0].x; + } + + // Do a CTA-wide reduction. + template + inline __device__ float reduce_() { + // The final reduction. + float red = 0.f; + + // SEQLEN == 128. + if (Cta_tile::WARPS_M == 2 && Cta_tile::WARPS_N == 2) { + red = reduce_2x2(); + + // SEQLEN == 256. + } else if (Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 4) { + red = reduce_1x4(); + + // SEQLEN == 256. + } else if (Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 8) { + red = reduce_1x8(); + + // Not supported. + } else { + assert(false); + } + + return red; + } + + // Finalize the reduction. + inline __device__ void shuffle(float (&dst)[MMAS_M], float red) { + // Store the value back to shared memory. + if (tidx_ < Cta_tile::M) { + smem_[tidx_] = red; + } + + // Make sure the data is in shared memory. + __syncthreads(); + +// Finally read the values. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[mi] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA]; + } + + // Make sure we are done reading shared memory. + __syncthreads(); + } + + // Do a CTA-wide reduction. + template + inline __device__ void reduce(float (&dst)[MMAS_M]) { + // NOTE: 1 warp along reduce direction, no syncs + if (Cta_tile::WARPS_N == 1) { + reduce_Nx1(dst); + } else { + // The result of the reduction. Threads 0..Cta_tile::M-1 own a valid value. + float red = reduce_(); + + // Make sure we can write to shared memory. + __syncthreads(); + + // Finalize the reduction. + shuffle(dst, red); + } + } + + // Scale all the elements. + inline __device__ void scale(float const (&sum)[MMAS_M]) { + // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal. + float inv_sum[MMAS_M]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi]; + } + +// Update the values. 
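+// Note: sum[mi] != sum[mi] detects NaN. A zero (or NaN) row sum occurs when every element
+// of the row was masked to -FLT_MAX; the 1.f fallback above leaves such rows at zero
+// instead of producing Inf/NaN during normalization.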
+#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 8; ++ni) { + elt_[mi][ni] *= inv_sum[mi]; + } + } + } + + // Store the tile after softmax. + template + inline __device__ void store(Gmem_tile& gmem_tile) { + Accumulator acc[MMAS_M][MMAS_N]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // The elements. + float tmp_00 = this->elt_[mi][8 * ni + 0]; + float tmp_01 = this->elt_[mi][8 * ni + 1]; + float tmp_02 = this->elt_[mi][8 * ni + 2]; + float tmp_03 = this->elt_[mi][8 * ni + 3]; + float tmp_04 = this->elt_[mi][8 * ni + 4]; + float tmp_05 = this->elt_[mi][8 * ni + 5]; + float tmp_06 = this->elt_[mi][8 * ni + 6]; + float tmp_07 = this->elt_[mi][8 * ni + 7]; + + // Transform to accumulators. + acc[mi][ni].reg(0) = fmha::float2_to_16bit_2(tmp_00, tmp_01); + acc[mi][ni].reg(1) = fmha::float2_to_16bit_2(tmp_02, tmp_03); + acc[mi][ni].reg(2) = fmha::float2_to_16bit_2(tmp_04, tmp_05); + acc[mi][ni].reg(3) = fmha::float2_to_16bit_2(tmp_06, tmp_07); + } + } + + // Delegate to the gmem tile to store. + gmem_tile.store(acc); + } + + // Convert from FP16 fragments to floats. + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // Normalize the values, and clamp to finite half. + uint32_t acc_0 = satfinite_h2(hmul2(acc[mi][ni].reg(0), params_scale_bmm1_)); + uint32_t acc_1 = satfinite_h2(hmul2(acc[mi][ni].reg(1), params_scale_bmm1_)); + uint32_t acc_2 = satfinite_h2(hmul2(acc[mi][ni].reg(2), params_scale_bmm1_)); + uint32_t acc_3 = satfinite_h2(hmul2(acc[mi][ni].reg(3), params_scale_bmm1_)); + + // Extract the values as floats. + half2_to_float2(this->elt_[mi][8 * ni + 0], this->elt_[mi][8 * ni + 1], acc_0); + half2_to_float2(this->elt_[mi][8 * ni + 2], this->elt_[mi][8 * ni + 3], acc_1); + half2_to_float2(this->elt_[mi][8 * ni + 4], this->elt_[mi][8 * ni + 5], acc_2); + half2_to_float2(this->elt_[mi][8 * ni + 6], this->elt_[mi][8 * ni + 7], acc_3); + + if constexpr (ENABLE_BMM1_SOFTCAPPING_SCALE) { +#pragma unroll + for (int i = 0; i < 8; i++) { + // 1.0f / softcapping_scale has been fused to scale_bmm1. + this->elt_[mi][8 * ni + i] = + params_softcapping_scale_bmm1_ * __tanhf(this->elt_[mi][8 * ni + i]); + } + } + } + } + } + + // The scaling factor. + uint32_t const params_scale_bmm1_; + float const params_softcapping_scale_bmm1_; + // Shared memory for the CTA-wide reduction. + float *smem_, *smem_write_, *smem_read_; + // The current thread index. + int tidx_; + // The elements. + float elt_[MMAS_M][MMAS_N * 8]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_hmma { + // The traits. + using Traits = fmha::Turing_hmma_fp16_traits; + // The base class. + using Base = Softmax_hmma; + // The fragment. + using Fragment_a = fmha::Fragment_a; + // Softmax dst data_type (BMM2 input) + using Dst_type = typename Traits::A_type; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} + + // Pack the data to a fragment for the next GEMM. 
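+  // Each K step takes two adjacent columns from a pair of softmax rows: reg(0) packs the
+  // two elements of row 2*mi as fp16x2, reg(1) the two elements of row 2*mi+1.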
+ template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { + static_assert(Base::Mma_tile::MMAS_M == M && Base::Mma_tile::MMAS_N * 4 == K * 2, ""); +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 2 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][2 * ki + 0]; + float tmp_01 = this->elt_[2 * mi + 0][2 * ki + 1]; + + // 2nd row - 2 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][2 * ki + 0]; + float tmp_11 = this->elt_[2 * mi + 1][2 * ki + 1]; + + // Pack to 2 registers. + dst[ki][mi].reg(0) = fmha::float2_to_16bit_2(tmp_00, tmp_01); + dst[ki][mi].reg(1) = fmha::float2_to_16bit_2(tmp_10, tmp_11); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_imma { + // The traits. + using Traits = fmha::Volta_imma_int8_int32_traits; + // The base class. + using Base = Softmax_imma; + // The fragment. + using Fragment_a = fmha::Fragment_a; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} + + // Repack. We could use store/load to match the Smem_tile API. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) { + static_assert(Base::Mma_tile::MMAS_M == M && Base::Mma_tile::MMAS_N == K, ""); + float const scale = reinterpret_cast(this->params_scale_softmax_); +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0] * scale; + float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1] * scale; + float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2] * scale; + float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3] * scale; + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0] * scale; + float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1] * scale; + float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2] * scale; + float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3] * scale; + + // Pack to 2 registers. + dst[ki][mi].reg(0) = float4_to_char4(tmp_00, tmp_01, tmp_02, tmp_03); + dst[ki][mi].reg(1) = float4_to_char4(tmp_10, tmp_11, tmp_12, tmp_13); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_imma { + // The traits. + using Traits = fmha::Turing_imma_int8_int32_traits; + // The base class. + using Base = Softmax_imma; + // The fragment. + using Fragment_a = fmha::Fragment_a; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} + + // Repack. We could use store/load to match the Smem_tile API. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) { + static_assert(Base::Mma_tile::MMAS_M == M && Base::Mma_tile::MMAS_N == K, ""); + float const scale = reinterpret_cast(this->params_scale_softmax_); +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0] * scale; + float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1] * scale; + float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2] * scale; + float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3] * scale; + + // 2nd row - 4 elements per row. 
+ float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0] * scale; + float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1] * scale; + float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2] * scale; + float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3] * scale; + + // Pack to 2 registers. + dst[ki][mi].reg(0) = float4_to_char4(tmp_00, tmp_01, tmp_02, tmp_03); + dst[ki][mi].reg(1) = float4_to_char4(tmp_10, tmp_11, tmp_12, tmp_13); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_hmma { + // The traits. + using Traits = fmha::Ampere_hmma_fp16_traits; + // The base class. + using Base = Softmax_hmma; + // The fragment. + using Fragment_a = fmha::Fragment_a; + // Softmax dst data_type (BMM2 input) + using Dst_type = typename Traits::A_type; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3]; + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float2_to_16bit_2(tmp_00, tmp_01); + dst[ki][mi].reg(1) = fmha::float2_to_16bit_2(tmp_10, tmp_11); + dst[ki][mi].reg(2) = fmha::float2_to_16bit_2(tmp_02, tmp_03); + dst[ki][mi].reg(3) = fmha::float2_to_16bit_2(tmp_12, tmp_13); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_fp32 : public Softmax_hmma { + // The base class. + using Base = Softmax_hmma; + // The fragment. + using Fragment_a = fmha::Fragment_a; + + // The MMAs. + enum { MMAS_M = Base::MMAS_M }; + + enum { MMAS_N = Base::MMAS_N }; + + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + // Output accumulators (after conversion). + using Accumulator_out = fmha::Fragment_accumulator; + + // Softmax dst data_type (BMM2 input) + using Dst_type = typename Traits::A_type; + + // DEBUG. + static_assert(Accumulator_out::NUM_REGS == 4, ""); + // END OF DEBUG. + + // DEBUG. + static_assert(std::is_same::value, ""); + + // END OF DEBUG. + + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + // Use BMM1 softcapping scale or not. + enum { ENABLE_BMM1_SOFTCAPPING_SCALE = Kernel_traits::ENABLE_BMM1_SOFTCAPPING_SCALE }; + + using Smem_tile_red = Smem_tile_reduce; + static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N); + + // Ctor. + template + inline __device__ Softmax_fp32(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx), + smem_sum_(static_cast(smem), tidx), + smem_max_(static_cast(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) {} + + // Store the tile after softmax. 
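+  // The FP32 probabilities are converted in pairs to the 16-bit Dst_type (Accumulator_out)
+  // before being handed to the gmem tile.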
+ template + inline __device__ void store(Gmem_tile& gmem_tile) { + Accumulator_out acc[MMAS_M][MMAS_N]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // The elements. + float tmp_00 = this->elt_[2 * mi + 0][4 * ni + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ni + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ni + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ni + 3]; + float tmp_10 = this->elt_[2 * mi + 1][4 * ni + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ni + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ni + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ni + 3]; + + // Transform to accumulators. + acc[mi][ni].reg(0) = fmha::float2_to_16bit_2(tmp_00, tmp_01); + acc[mi][ni].reg(1) = fmha::float2_to_16bit_2(tmp_10, tmp_11); + acc[mi][ni].reg(2) = fmha::float2_to_16bit_2(tmp_02, tmp_03); + acc[mi][ni].reg(3) = fmha::float2_to_16bit_2(tmp_12, tmp_13); + } + } + + // Delegate to the gmem tile to store. + gmem_tile.store(acc); + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { + static_assert(Fragment_a::NUM_REGS == 4, ""); +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3]; + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float2_to_16bit_2(tmp_00, tmp_01); + dst[ki][mi].reg(1) = fmha::float2_to_16bit_2(tmp_10, tmp_11); + dst[ki][mi].reg(2) = fmha::float2_to_16bit_2(tmp_02, tmp_03); + dst[ki][mi].reg(3) = fmha::float2_to_16bit_2(tmp_12, tmp_13); + } + } + } + + // Pack the data to a uint4 for the next operation. + template + inline __device__ void pack(uint4 (&dst)[M][N]) const { +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ni + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ni + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ni + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ni + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][4 * ni + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ni + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ni + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ni + 3]; + + // Pack to 4 registers. + dst[mi][ni].x = fmha::float2_to_16bit_2(tmp_00, tmp_01); + dst[mi][ni].y = fmha::float2_to_16bit_2(tmp_02, tmp_03); + dst[mi][ni].z = fmha::float2_to_16bit_2(tmp_10, tmp_11); + dst[mi][ni].w = fmha::float2_to_16bit_2(tmp_12, tmp_13); + } + } + } + + // Scale FP32 fragments + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { + float const scalef = reinterpret_cast(this->params_scale_bmm1_); + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // 1st row - 4 elements per row. 
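+        // (accumulator registers 0,1,4,5; the 2nd row below uses 2,3,6,7). unpack()
+        // applies the BMM1 scale while remapping into contiguous per-row elt_ storage.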
+ this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef; + this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef; + this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef; + this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef; + // 2nd row - 4 elements per row. + this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef; + this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef; + this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef; + this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef; + + // Attention logit softcapping scale. + // 1.0f / softcapping_scale has been fused to scale_bmm1. + if constexpr (ENABLE_BMM1_SOFTCAPPING_SCALE) { + this->elt_[2 * mi + 0][4 * ni + 0] = + this->params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 0][4 * ni + 0]); + this->elt_[2 * mi + 0][4 * ni + 1] = + this->params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 0][4 * ni + 1]); + this->elt_[2 * mi + 1][4 * ni + 0] = + this->params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 1][4 * ni + 0]); + this->elt_[2 * mi + 1][4 * ni + 1] = + this->params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 1][4 * ni + 1]); + this->elt_[2 * mi + 0][4 * ni + 2] = + this->params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 0][4 * ni + 2]); + this->elt_[2 * mi + 0][4 * ni + 3] = + this->params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 0][4 * ni + 3]); + this->elt_[2 * mi + 1][4 * ni + 2] = + this->params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 1][4 * ni + 2]); + this->elt_[2 * mi + 1][4 * ni + 3] = + this->params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 1][4 * ni + 3]); + } + } + } + } + + // Scale FP32 fragments + inline __device__ void unpack_noscale(Accumulator const (&acc)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // 1st row - 4 elements per row. + this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0); + this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1); + this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4); + this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5); + // 2nd row - 4 elements per row. 
+ this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2); + this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3); + this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6); + this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7); + } + } + } + + template + __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator& op, Smem_tile_red& smem_red) { +#pragma unroll + for (int mi = 0; mi < 2 * MMAS_M; mi++) { + frag[mi] = this->elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < 4 * MMAS_N; ni++) { + frag[mi] = op(frag[mi], this->elt_[mi][ni]); + } + } + quad_reduce(frag, frag, op); + + if (WARPS_N > 1) { + smem_red.store(frag); + __syncthreads(); + typename Smem_tile_red::read_t tmp[2 * MMAS_M]; + smem_red.load(tmp); + + quad_allreduce(frag, tmp, op); + } + } + + __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]) { + MaxOp max; + reduce_(frag, max, smem_max_); + } + + __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]) { + SumOp sum; + reduce_(frag, sum, smem_sum_); + } + + __device__ inline float correct(float warp_sum, float warp_max, float max) { + return warp_sum * __expf(warp_max - max); + } + + __device__ inline float2 correct(float2 warp_sum, float2 warp_max, float max) { + return {correct(warp_sum.x, warp_max.x, max), correct(warp_sum.y, warp_max.y, max)}; + } + + __device__ inline void online_softmax() { + MaxOp maxOp; + SumOp sumOp; + float max[2 * MMAS_M]; +#pragma unroll + for (int mi = 0; mi < 2 * MMAS_M; mi++) { + max[mi] = this->elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < 4 * MMAS_N; ni++) { + max[mi] = maxOp(max[mi], this->elt_[mi][ni]); + } + } + quad_allreduce(max, max, maxOp); + smem_max_.store(max); + float sum[2 * MMAS_M]; +#pragma unroll + for (int mi = 0; mi < 2 * MMAS_M; mi++) { + sum[mi] = 0.f; +#pragma unroll + for (int ni = 0; ni < 4 * MMAS_N; ni++) { + float x = this->elt_[mi][ni]; + this->elt_[mi][ni] = __expf(x - max[mi]); + sum[mi] += this->elt_[mi][ni]; + } + } + quad_allreduce(sum, sum, sumOp); + smem_sum_.store(sum); + + __syncthreads(); + + typename Smem_tile_red::read_t tmp_max[2 * MMAS_M]; + typename Smem_tile_red::read_t tmp_sum[2 * MMAS_M]; + smem_max_.load(tmp_max); + smem_sum_.load(tmp_sum); + float full_max[2 * MMAS_M]; + quad_allreduce(full_max, tmp_max, maxOp); +#pragma unroll + for (int mi = 0; mi < 2 * MMAS_M; mi++) { + tmp_sum[mi] = correct(tmp_sum[mi], tmp_max[mi], full_max[mi]); + } + quad_allreduce(sum, tmp_sum, sumOp); +#pragma unroll + for (int mi = 0; mi < 2 * MMAS_M; mi++) { + float correction = __expf(max[mi] - full_max[mi]) / sum[mi]; +#pragma unroll + for (int ni = 0; ni < 4 * MMAS_N; ni++) { + this->elt_[mi][ni] *= correction; + } + } + } + + Smem_tile_red smem_max_; + Smem_tile_red smem_sum_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_fp32 { + // The traits. + using Traits = fmha::Ampere_hmma_fp32_traits; + // The base class. + using Base = Softmax_fp32; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_fp32 { + // The traits. + using Traits = fmha::Turing_hmma_fp32_traits; + // The base class. + using Base = Softmax_fp32; + // The fragment. 
+ using Fragment_a = fmha::Fragment_a; + // Softmax dst data_type (BMM2 input) + using Dst_type = typename Traits::A_type; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { + static_assert(Fragment_a::NUM_REGS == 2, ""); + static_assert(Base::Mma_tile::MMAS_M == M && Base::Mma_tile::MMAS_N * 4 == K * 2, ""); +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 2 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][2 * ki + 0]; + float tmp_01 = this->elt_[2 * mi + 0][2 * ki + 1]; + + // 2nd row - 2 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][2 * ki + 0]; + float tmp_11 = this->elt_[2 * mi + 1][2 * ki + 1]; + + // Pack to 2 registers. + dst[ki][mi].reg(0) = fmha::float2_to_16bit_2(tmp_00, tmp_01); + dst[ki][mi].reg(1) = fmha::float2_to_16bit_2(tmp_10, tmp_11); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_fp32 { + // The traits. + using Traits = fmha::Ampere_hmma_bf16_traits; + // The base class. + using Base = Softmax_fp32; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_imma { + // The traits. + using Traits = fmha::Ampere_imma_int8_int32_traits; + // The base class. + using Base = Softmax_imma; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_qmma { + // The traits. + using Traits = fmha::Ada_qmma_e4m3_fp32_traits; + // The base class. + using Base = Softmax_qmma; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_qmma { + // The traits. + using Traits = fmha::Ada_qmma_e4m3_fp16_traits; + // The base class. + using Base = Softmax_qmma; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_imma { + // The Traits + using Traits = fmha::Ada_qmma_e4m3_fp32_traits; + // The base class. + using Base = Softmax_imma; + + // The MMAs. + enum { MMAS_M = Base::MMAS_M }; + + enum { MMAS_N = Base::MMAS_N }; + + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx), + params_scale_bmm1_(params.scale_bmm1_d ? *params.scale_bmm1_d : params.scale_bmm1), + params_scale_softmax_(params.scale_softmax) {} + + // Store the tile after softmax. 
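+  // Note: this variant consumes per-block Q/K dequantization scales (see
+  // move_to_first_block / move_to_next_block below): unpack() folds scale_q * scale_k into
+  // the logits, and pack() re-applies scale_softmax when quantizing back to FP8.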
+ template + inline __device__ void store(Gmem_tile& gmem_tile) { + float const scale = reinterpret_cast(params_scale_softmax_); + Accumulator acc[MMAS_M][MMAS_N]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // scale + acc[mi][ni].ele(0) = this->elt_[2 * mi + 0][4 * ni + 0] * scale; + acc[mi][ni].ele(1) = this->elt_[2 * mi + 0][4 * ni + 1] * scale; + acc[mi][ni].ele(4) = this->elt_[2 * mi + 0][4 * ni + 2] * scale; + acc[mi][ni].ele(5) = this->elt_[2 * mi + 0][4 * ni + 3] * scale; + acc[mi][ni].ele(2) = this->elt_[2 * mi + 1][4 * ni + 0] * scale; + acc[mi][ni].ele(3) = this->elt_[2 * mi + 1][4 * ni + 1] * scale; + acc[mi][ni].ele(6) = this->elt_[2 * mi + 1][4 * ni + 2] * scale; + acc[mi][ni].ele(7) = this->elt_[2 * mi + 1][4 * ni + 3] * scale; + } + } + + // Delegate to the gmem tile to store. + // TODO: need fp32 to fp8 conversion (move this to gmem_tile) + gmem_tile.store(acc); + } + + // Convert from accumulators to floats. + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { + float const scale = params_scale_q_ * params_scale_k_; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // Convert to FP32 and scale. + this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scale; + this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scale; + this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scale; + this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scale; + this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scale; + this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scale; + this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scale; + this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scale; + } + } + } + + template + inline __device__ void apply_exp_with_mask(float const (&max)[MMAS_M * 2]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + float max_val = APPLY_MASK && max[mi] == -FLT_MAX + ? 0.f + : (max[mi] - logf(Traits::SOFTMAX_FP_QUANT_SCALE)); +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + this->elt_[mi][ni] = expf(this->elt_[mi][ni] - max_val); + } + } + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { + float const scale = reinterpret_cast(this->params_scale_softmax_); + +// The canonical layout in K should be R0: [0,1,2,3] R2: [16,17,18,19] +// Note below that this is not possible with the register layout of the accumulator. +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 8 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][8 * ki + 0] * scale; // + 0 + float tmp_01 = this->elt_[2 * mi + 0][8 * ki + 1] * scale; // + 1 + float tmp_02 = this->elt_[2 * mi + 0][8 * ki + 2] * scale; // + 8 + float tmp_03 = this->elt_[2 * mi + 0][8 * ki + 3] * scale; // + 9 + float tmp_04 = this->elt_[2 * mi + 0][8 * ki + 4] * scale; // +16 + float tmp_05 = this->elt_[2 * mi + 0][8 * ki + 5] * scale; // +17 + float tmp_06 = this->elt_[2 * mi + 0][8 * ki + 6] * scale; // +24 + float tmp_07 = this->elt_[2 * mi + 0][8 * ki + 7] * scale; // +25 + + // 2nd row - 4 elements per row. 
+ float tmp_10 = this->elt_[2 * mi + 1][8 * ki + 0] * scale; // + 0 + float tmp_11 = this->elt_[2 * mi + 1][8 * ki + 1] * scale; // + 1 + float tmp_12 = this->elt_[2 * mi + 1][8 * ki + 2] * scale; // + 8 + float tmp_13 = this->elt_[2 * mi + 1][8 * ki + 3] * scale; // + 9 + float tmp_14 = this->elt_[2 * mi + 1][8 * ki + 4] * scale; // +16 + float tmp_15 = this->elt_[2 * mi + 1][8 * ki + 5] * scale; // +17 + float tmp_16 = this->elt_[2 * mi + 1][8 * ki + 6] * scale; // +24 + float tmp_17 = this->elt_[2 * mi + 1][8 * ki + 7] * scale; // +25 + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float4_to_fp8x4(tmp_00, tmp_01, tmp_02, tmp_03); + dst[ki][mi].reg(1) = fmha::float4_to_fp8x4(tmp_10, tmp_11, tmp_12, tmp_13); + dst[ki][mi].reg(2) = fmha::float4_to_fp8x4(tmp_04, tmp_05, tmp_06, tmp_07); + dst[ki][mi].reg(3) = fmha::float4_to_fp8x4(tmp_14, tmp_15, tmp_16, tmp_17); + } + } + } + + template + inline __device__ void move_to_first_block(Params const& params, int bidb, int bidh, int q_loop) { + int scale_q_iter = + bidb * params.h * params.sage.q.max_nblock + bidh * params.sage.q.max_nblock + q_loop; + params_scale_q_ = __ldg(params.sage.q.scales + scale_q_iter); + params_scale_q_ *= reinterpret_cast(params_scale_bmm1_); + + int scale_k_iter = bidb * params.h * params.sage.k.max_nblock + bidh * params.sage.k.max_nblock; + params_scale_k_iter = reinterpret_cast(params.sage.k.scales + scale_k_iter); + params_scale_k_ = __ldg(params_scale_k_iter); + } + + inline __device__ void move_to_next_block() { + params_scale_k_iter += 1; + params_scale_k_ = __ldg(params_scale_k_iter); + } + + // The scaling factors. + uint32_t const params_scale_bmm1_, params_scale_softmax_; + float params_scale_q_, params_scale_k_; + float const* params_scale_k_iter; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// HOPPER SOFTMAX + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_gmma_base {}; + +template +struct Softmax_gmma_base { + // The instruction traits. + using Traits = Traits_; + // The Cta_tile. + using Cta_tile = Cta_tile_; + // The Kernel traits. + using Kernel_traits = Kernel_traits_; + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + static_assert(Cta_tile::WARPS_M == 4); + static_assert(Mma_tile::M_PER_MMA_PER_CTA == 64); + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // Elements per thread per core matrix. + enum { ELTS_PER_THREAD = 2 }; + + // Core matrix is always 8x4. + enum { THREADS_PER_ROW = 4 }; + + enum { SMEM_BYTES = 0 }; + + // The number of rows accessed by each thread. + enum { + ROWS_PER_THREAD = + Traits::GMMA_M / (Cta_tile::THREADS_PER_WARP / THREADS_PER_ROW) / Cta_tile::WARPS_M + }; + + static_assert(ROWS_PER_THREAD == Mma_tile::ROWS_PER_THREAD); + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLS_PER_THREAD = Traits::GMMA_N / THREADS_PER_ROW / ELTS_PER_THREAD }; + + // The number of total elements per thread. 
+ enum { TOTAL_ELTS_PER_THREAD = ELTS_PER_THREAD * COLS_PER_THREAD }; + + template + inline __device__ Softmax_gmma_base(Params const& params, void*, int const, int const) + : params_scale_bmm1_(params.scale_bmm1), + params_softcapping_scale_bmm1_(params.softcapping_scale_bmm1) {} + + // Apply mask before softmax. Use 1 byte per MMA distributed as 2x4. + template + inline __device__ void apply_mask(Mask const& mask) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ii = 0; ii < ROWS_PER_THREAD; ++ii) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int jj = 0; jj < TOTAL_ELTS_PER_THREAD; ++jj) { + if (!mask.is_valid(mi, ni, ii, jj)) { + this->elt_[ROWS_PER_THREAD * mi + ii][TOTAL_ELTS_PER_THREAD * ni + jj] = -FLT_MAX; + } + } // jj + } // ni + } // ii + } // mi + } + + template + inline __device__ void apply_mask_alibi(Mask const& mask, int head_id, + AlibiParams const& alibi_params) { + // 'if constexpr' because ALiBi is only defined for causal masks + if constexpr (Kernel_traits::CAUSAL_MASK) { + float m = get_alibi_head_scaling_factor(head_id, alibi_params); +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ii = 0; ii < ROWS_PER_THREAD; ++ii) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int jj = 0; jj < TOTAL_ELTS_PER_THREAD; ++jj) { + int row, col; + mask.get_row_col(row, col, mi, ni, ii, jj); + if (mask.is_valid(row, col)) { + // Since softmax is shift invariant, + // it is sufficient just to use the column as the multiplier + elt_[ROWS_PER_THREAD * mi + ii][TOTAL_ELTS_PER_THREAD * ni + jj] = + elt_[ROWS_PER_THREAD * mi + ii][TOTAL_ELTS_PER_THREAD * ni + jj] * + alibi_params.scale_after_alibi + + m * (col + alibi_params.sequence_pos_offset); + } else { + elt_[ROWS_PER_THREAD * mi + ii][TOTAL_ELTS_PER_THREAD * ni + jj] = -FLT_MAX; + } + } + } + } + } + } else { + __builtin_unreachable(); + } + } + + // Do a CTA-wide reduction. + template + inline __device__ void reduce_4x1(float (&dst)[MMAS_M * ROWS_PER_THREAD]) { +#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + static_assert(MMAS_N * COLS_PER_THREAD * ELTS_PER_THREAD == MMAS_N * Mma_tile::CORES_N * 2); + if (Functor::IS_SUM) { +// Apply the summation inside the thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M * ROWS_PER_THREAD; ++mi) { + dst[mi] = (this->elt_[mi][0] + this->elt_[mi][1]); +#pragma unroll + for (int ni = 1; ni < MMAS_N * Mma_tile::CORES_N; ni++) { + dst[mi] += (this->elt_[mi][ni * 2 + 0] + this->elt_[mi][ni * 2 + 1]); + } + } + } else +#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + { +// find the max/sum for each row. +// For hopper, each row is held entirely within 4 threads. +// Apply the functor for each row inside a thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M * ROWS_PER_THREAD; ++mi) { + dst[mi] = this->elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < MMAS_N * COLS_PER_THREAD * ELTS_PER_THREAD; ++ni) { + dst[mi] = Functor::apply(dst[mi], this->elt_[mi][ni]); + } + } + } +// Apply the functor for each row inside each group of 4 threads. +#pragma unroll + for (int mi = 0; mi < MMAS_M * ROWS_PER_THREAD; ++mi) { + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1)); + __syncwarp(); + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2)); + __syncwarp(); + } + } + + // Do a CTA-wide reduction. 
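+  // With the 4x1 warp layout each row is owned by a single quad, so the two XOR shuffles
+  // in reduce_4x1() already complete the row reduction and no shared-memory round trip
+  // is needed.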
+ template + inline __device__ void reduce(float (&dst)[MMAS_M * ROWS_PER_THREAD]) { + reduce_4x1(dst); + } + + // Apply the exp to all the elements. + inline __device__ void apply_exp(float const (&max)[MMAS_M * ROWS_PER_THREAD]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M * ROWS_PER_THREAD; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * COLS_PER_THREAD * ELTS_PER_THREAD; ++ni) { + this->elt_[mi][ni] = apply_exp_(this->elt_[mi][ni], max[mi]); + } + } + } + + // Scale all the elements. + inline __device__ void scale(float const (&sum)[MMAS_M * ROWS_PER_THREAD]) { + // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal. + float inv_sum[MMAS_M * ROWS_PER_THREAD]; +#pragma unroll + for (int mi = 0; mi < MMAS_M * ROWS_PER_THREAD; ++mi) { + inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi]; + } + +// Update the values. +#pragma unroll + for (int mi = 0; mi < MMAS_M * ROWS_PER_THREAD; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * COLS_PER_THREAD * ELTS_PER_THREAD; ++ni) { + this->elt_[mi][ni] *= inv_sum[mi]; + } + } + } + + // The scalig factor. Depens on acc type, e.g. float for 32-bit and fp16x2/bf16x2 for 16-bit. + uint32_t const params_scale_bmm1_; + float const params_softcapping_scale_bmm1_; + // The elements. + float elt_[MMAS_M * ROWS_PER_THREAD][MMAS_N * COLS_PER_THREAD * ELTS_PER_THREAD]; +}; + +template +struct Softmax_gmma_base + : public Softmax_gmma_base { + using Base = Softmax_gmma_base; + + using Mma_tile = typename Base::Mma_tile; + + enum { BYTES_PER_SMEM = Mma_tile::M_PER_MMA_PER_CTA * Cta_tile::WARPS_N * sizeof(float) }; + + enum { ELTS_PER_ROW = 2 }; + + static_assert(Cta_tile::WARPS_N == 2); + static_assert(Cta_tile::WARPS_M == 4); + static_assert(Mma_tile::M_PER_MMA_PER_CTA == 64); + + template + inline __device__ Softmax_gmma_base(Params const& params, void* smem, int const bidb, + int const tidx) + : Base(params, smem, bidb, tidx) { + int const warp = tidx / Cta_tile::THREADS_PER_WARP; + int const warp_n = warp / 4; + int const warp_m = warp % 4; + int const lane = tidx % Cta_tile::THREADS_PER_WARP; + int const quad = lane / 4; + is_writer_ = lane % 4 == 0; + + int const col = warp_n; + int const row = warp_m * 16 + quad; + + smem_write_ = static_cast(smem) + row * 2 + col; + smem_read_ = static_cast(smem) + row; + } + + // Do a CTA-wide reduction. + template + inline __device__ void reduce(float (&dst)[2]) { + Base::template reduce_4x1(dst); + if (is_writer_) { + smem_write_[0 * ELTS_PER_ROW] = dst[0]; + smem_write_[8 * ELTS_PER_ROW] = dst[1]; + } + __syncthreads(); + float2 tmp0 = smem_read_[0]; + float2 tmp1 = smem_read_[8]; + dst[0] = Functor::apply(tmp0.x, tmp0.y); + dst[1] = Functor::apply(tmp1.x, tmp1.y); + } + + float* smem_write_; + float2* smem_read_; + bool is_writer_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax, + Cta_tile_, Kernel_traits_> + : public Softmax_gmma_base< + fmha::Hopper_hgmma_fp16_traits, Cta_tile_, + Kernel_traits_, Cta_tile_::WARPS_N> { + // The traits. + using Traits = fmha::Hopper_hgmma_fp16_traits; + // Cta_tile. + using Cta_tile = Cta_tile_; + // Kernel_traits. + using Kernel_traits = Kernel_traits_; + // The Base class. + using Base = Softmax_gmma_base; + // The accumulators. + using Accumulator = typename Base::Accumulator; + // The Mma tile. + using Mma_tile = typename Base::Mma_tile; + + // The number of MMAs in M/N dimensions. 
+ enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // for HGMMA_FP16, there are 2 elements per RF for ACC. + enum { ELTS_PER_THREAD = 2 }; + + // for Hopper HGMMA, each row is held within 4 threads. + enum { THREADS_PER_ROW = 4 }; + + // The number of rows accessed by each thread. + enum { + ROWS_PER_THREAD = + Traits::GMMA_M / (Cta_tile::THREADS_PER_WARP / THREADS_PER_ROW) / Cta_tile::WARPS_M + }; + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLS_PER_THREAD = Traits::GMMA_N / THREADS_PER_ROW / ELTS_PER_THREAD }; + + // Use BMM1 softcapping scale or not. + enum { ENABLE_BMM1_SOFTCAPPING_SCALE = Kernel_traits::ENABLE_BMM1_SOFTCAPPING_SCALE }; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} + + // Convert from FP16 fragments to floats. + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int col_idx = 0; col_idx < COLS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + // the order of the acc rf is we traverse vertically first + // then we traverse horizontally. + + // Normalize the values. + uint32_t acc_0 = fmha::hmul2(acc[mi][ni].reg(col_idx * ROWS_PER_THREAD + row_idx), + this->params_scale_bmm1_); + // Element index. + int elt_row_idx = ROWS_PER_THREAD * mi + row_idx; + int elt_col_idx = COLS_PER_THREAD * ELTS_PER_THREAD * ni + col_idx * ELTS_PER_THREAD; + // Extract the values as floats. + half2_to_float2(this->elt_[elt_row_idx][elt_col_idx + 0], + this->elt_[elt_row_idx][elt_col_idx + 1], acc_0); + // Attention logit softcapping scale. + // 1.0f / softcapping_scale has been fused to scale_bmm1. + if constexpr (ENABLE_BMM1_SOFTCAPPING_SCALE) { + this->elt_[elt_row_idx][elt_col_idx + 0] = + this->params_softcapping_scale_bmm1_ * + __tanhf(this->elt_[elt_row_idx][elt_col_idx + 0]); + this->elt_[elt_row_idx][elt_col_idx + 1] = + this->params_softcapping_scale_bmm1_ * + __tanhf(this->elt_[elt_row_idx][elt_col_idx + 1]); + } + } // row_idx + } // col_idx + } // ni + } // mi + } + + // Store the tile after softmax. + template + inline __device__ void store(Gmem_tile& gmem_tile) { + Accumulator acc[MMAS_M][MMAS_N]; + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int col_idx = 0; col_idx < COLS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + // the order of the acc rf is we traverse vertically first + // then we traverse horizontally. + float tmp_00 = + this->elt_[ROWS_PER_THREAD * mi + row_idx] + [COLS_PER_THREAD * ELTS_PER_THREAD * ni + col_idx * ELTS_PER_THREAD + 0]; + float tmp_01 = + this->elt_[ROWS_PER_THREAD * mi + row_idx] + [COLS_PER_THREAD * ELTS_PER_THREAD * ni + col_idx * ELTS_PER_THREAD + 1]; + acc[mi][ni].reg(col_idx * ROWS_PER_THREAD + row_idx) = + fmha::float2_to_half2(tmp_00, tmp_01); + } // row_idx + } // col_idx + } // ni + } // m + + // Delegate to the gmem tile to store. + gmem_tile.store(acc); + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { +// we know the instruction shape is 64xNx16 +// Thus for input A matrix, it is of size 64x16 per warpgroup. 
+// Thus, each threads access 2 rows and 4 columns. contiguous 2 columns are held by 1 RF. +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3]; + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01); + dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11); + dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03); + dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax, + Cta_tile_, Kernel_traits_> + : public Softmax_gmma_base< + fmha::Hopper_hgmma_fp32_traits, Cta_tile_, + Kernel_traits_, Cta_tile_::WARPS_N> { + // The traits. + using Traits = fmha::Hopper_hgmma_fp32_traits; + // Cta_tile. + using Cta_tile = Cta_tile_; + // Kernel_traits. + using Kernel_traits = Kernel_traits_; + // The Base class. + using Base = Softmax_gmma_base; + // The accumulators. + using Accumulator = typename Base::Accumulator; + // The Mma tile. + using Mma_tile = typename Base::Mma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // for HGMMA_FP16, there are 2 elements per RF for ACC. + enum { ELTS_PER_THREAD = 2 }; + + // for Hopper HGMMA, each row is held within 4 threads. + enum { THREADS_PER_ROW = 4 }; + + // The number of rows accessed by each thread. + enum { + ROWS_PER_THREAD = + Traits::GMMA_M / (Cta_tile::THREADS_PER_WARP / THREADS_PER_ROW) / Cta_tile::WARPS_M + }; + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLS_PER_THREAD = Traits::GMMA_N / THREADS_PER_ROW / ELTS_PER_THREAD }; + + // Use BMM1 softcapping scale or not. + enum { ENABLE_BMM1_SOFTCAPPING_SCALE = Kernel_traits::ENABLE_BMM1_SOFTCAPPING_SCALE }; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} + + // Convert from FP16 fragments to floats. + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { + float const& scale_f = reinterpret_cast(this->params_scale_bmm1_); +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int col_idx = 0; col_idx < COLS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + // the order of the acc rf is we traverse vertically first + // then we traverse horizontally. + int elt_row = ROWS_PER_THREAD * mi + row_idx; + int elt_col = COLS_PER_THREAD * ELTS_PER_THREAD * ni + col_idx * ELTS_PER_THREAD; + + float elt0 = acc[mi][ni].elt(col_idx * 2 * ROWS_PER_THREAD + 2 * row_idx + 0) * scale_f; + float elt1 = acc[mi][ni].elt(col_idx * 2 * ROWS_PER_THREAD + 2 * row_idx + 1) * scale_f; + + // 1.0f / softcapping_scale has been fused to scale_bmm1. 
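+            // Softcapping bounds the logits to (-softcapping_scale, softcapping_scale)
+            // via s * tanh(x / s); the division by s is already folded into scale_bmm1,
+            // so only the multiply by s remains here.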
+ if constexpr (ENABLE_BMM1_SOFTCAPPING_SCALE) { + elt0 = this->params_softcapping_scale_bmm1_ * __tanhf(elt0); + elt1 = this->params_softcapping_scale_bmm1_ * __tanhf(elt1); + } + + this->elt_[elt_row][elt_col + 0] = elt0; + this->elt_[elt_row][elt_col + 1] = elt1; + + } // row_idx + } // col_idx + } // ni + } // mi + } + + // Store the tile after softmax. + template + inline __device__ void store(Gmem_tile& gmem_tile) { + Accumulator acc[MMAS_M][MMAS_N]; + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int col_idx = 0; col_idx < COLS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + // the order of the acc rf is we traverse vertically first + // then we traverse horizontally + int elt_row = ROWS_PER_THREAD * mi + row_idx; + int elt_col = COLS_PER_THREAD * ELTS_PER_THREAD * ni + col_idx * ELTS_PER_THREAD; + float elt0 = this->elt_[elt_row][elt_col + 0]; + float elt1 = this->elt_[elt_row][elt_col + 1]; + + acc[mi][ni].elt(col_idx * 2 * ROWS_PER_THREAD + 2 * row_idx + 0) = elt0; + acc[mi][ni].elt(col_idx * 2 * ROWS_PER_THREAD + 2 * row_idx + 1) = elt1; + } // row_idx + } // col_idx + } // ni + } // m + + // Delegate to the gmem tile to store. + gmem_tile.store(acc); + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { +// we know the instruction shape is 64xNx16 +// Thus for input A matrix, it is of size 64x16 per warpgroup. +// Thus, each threads access 2 rows and 4 columns. contiguous 2 columns are held by 1 RF. +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3]; + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01); + dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11); + dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03); + dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax, + Cta_tile_, Kernel_traits_> + : public Softmax_gmma_base< + fmha::Hopper_hgmma_bf16_traits, Cta_tile_, + Kernel_traits_, Cta_tile_::WARPS_N> { + // The traits. + using Traits = fmha::Hopper_hgmma_bf16_traits; + // Cta_tile. + using Cta_tile = Cta_tile_; + // Kernel_traits. + using Kernel_traits = Kernel_traits_; + // The Base class. + using Base = Softmax_gmma_base; + // The accumulators. + using Accumulator = typename Base::Accumulator; + // The Mma tile. + using Mma_tile = typename Base::Mma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // for HGMMA_FP16, there are 2 elements per RF for ACC. + enum { ELTS_PER_THREAD = 2 }; + + // for Hopper HGMMA, each row is held within 4 threads. + enum { THREADS_PER_ROW = 4 }; + + // The number of rows accessed by each thread. 
+ enum { + ROWS_PER_THREAD = + Traits::GMMA_M / (Cta_tile::THREADS_PER_WARP / THREADS_PER_ROW) / Cta_tile::WARPS_M + }; + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLS_PER_THREAD = Traits::GMMA_N / THREADS_PER_ROW / ELTS_PER_THREAD }; + + // Use BMM1 softcapping scale or not. + enum { ENABLE_BMM1_SOFTCAPPING_SCALE = Kernel_traits::ENABLE_BMM1_SOFTCAPPING_SCALE }; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} + + // Convert from FP16 fragments to floats. + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { + float const& scale_f = reinterpret_cast(this->params_scale_bmm1_); +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int col_idx = 0; col_idx < COLS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + // the order of the acc rf is we traverse vertically first + // then we traverse horizontally. + int elt_row = ROWS_PER_THREAD * mi + row_idx; + int elt_col = COLS_PER_THREAD * ELTS_PER_THREAD * ni + col_idx * ELTS_PER_THREAD; + + float elt0 = acc[mi][ni].elt(col_idx * 2 * ROWS_PER_THREAD + 2 * row_idx + 0) * scale_f; + float elt1 = acc[mi][ni].elt(col_idx * 2 * ROWS_PER_THREAD + 2 * row_idx + 1) * scale_f; + + if constexpr (ENABLE_BMM1_SOFTCAPPING_SCALE) { + elt0 = this->params_softcapping_scale_bmm1_ * __tanhf(elt0); + elt1 = this->params_softcapping_scale_bmm1_ * __tanhf(elt1); + } + + this->elt_[elt_row][elt_col + 0] = elt0; + this->elt_[elt_row][elt_col + 1] = elt1; + + } // row_idx + } // col_idx + } // ni + } // mi + } + + // Store the tile after softmax. + template + inline __device__ void store(Gmem_tile& gmem_tile) { + Accumulator acc[MMAS_M][MMAS_N]; + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int col_idx = 0; col_idx < COLS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + // the order of the acc rf is we traverse vertically first + // then we traverse horizontally. + int elt_row = ROWS_PER_THREAD * mi + row_idx; + int elt_col = COLS_PER_THREAD * ELTS_PER_THREAD * ni + col_idx * ELTS_PER_THREAD; + float elt0 = this->elt_[elt_row][elt_col + 0]; + float elt1 = this->elt_[elt_row][elt_col + 1]; + + acc[mi][ni].elt(col_idx * 2 * ROWS_PER_THREAD + 2 * row_idx + 0) = elt0; + acc[mi][ni].elt(col_idx * 2 * ROWS_PER_THREAD + 2 * row_idx + 1) = elt1; + } // row_idx + } // col_idx + } // ni + } // m + + // Delegate to the gmem tile to store. + gmem_tile.store(acc); + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { +// we know the instruction shape is 64xNx16 +// Thus for input A matrix, it is of size 64x16 per warpgroup. +// Thus, each threads access 2 rows and 4 columns. contiguous 2 columns are held by 1 RF. +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. 
+ float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3]; + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float2_to_bf16_x2(tmp_00, tmp_01); + dst[ki][mi].reg(1) = fmha::float2_to_bf16_x2(tmp_10, tmp_11); + dst[ki][mi].reg(2) = fmha::float2_to_bf16_x2(tmp_02, tmp_03); + dst[ki][mi].reg(3) = fmha::float2_to_bf16_x2(tmp_12, tmp_13); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_gmma_32bit_8bit_base + : public Softmax_gmma_base { + // The Base class. + using Base = Softmax_gmma_base; + // The accumulators. + using Accumulator = typename Base::Accumulator; + // The Mma tile. + using Mma_tile = typename Base::Mma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // TODO these should be general. + // Two elts per thread per acc core matrix. + enum { ELTS_PER_THREAD = 2 }; + + // Number of threads per row of the acc core matrix. + enum { THREADS_PER_ROW = 4 }; + + // The number of rows accessed by each thread per GMMA. + enum { + ROWS_PER_THREAD = + Traits::GMMA_M / (Cta_tile::THREADS_PER_WARP / THREADS_PER_ROW) / Cta_tile::WARPS_M + }; + + // The number of columns access by each thread. + enum { COLS_PER_THREAD = Traits::GMMA_N / THREADS_PER_ROW / ELTS_PER_THREAD }; + + // Check the expected number of accumulator elements. + static_assert(Accumulator::NUM_ELTS == COLS_PER_THREAD * ROWS_PER_THREAD * ELTS_PER_THREAD); + + // Ctor. + template + inline __device__ Softmax_gmma_32bit_8bit_base(Params const& params, void* smem, int bidb, + int tidx) + : Base(params, smem, bidb, tidx) {} + + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { + float const scalef = reinterpret_cast(this->params_scale_bmm1_); +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < COLS_PER_THREAD; ++ii) { + float tmp_00 = + acc[mi][ni].elt(ii * ROWS_PER_THREAD * ELTS_PER_THREAD + 0 * ELTS_PER_THREAD + 0) * + scalef; + float tmp_01 = + acc[mi][ni].elt(ii * ROWS_PER_THREAD * ELTS_PER_THREAD + 0 * ELTS_PER_THREAD + 1) * + scalef; + float tmp_10 = + acc[mi][ni].elt(ii * ROWS_PER_THREAD * ELTS_PER_THREAD + 1 * ELTS_PER_THREAD + 0) * + scalef; + float tmp_11 = + acc[mi][ni].elt(ii * ROWS_PER_THREAD * ELTS_PER_THREAD + 1 * ELTS_PER_THREAD + 1) * + scalef; + int n_offset = ni * COLS_PER_THREAD * ELTS_PER_THREAD + ii * ELTS_PER_THREAD; + this->elt_[mi * ROWS_PER_THREAD + 0][n_offset + 0] = tmp_00; + this->elt_[mi * ROWS_PER_THREAD + 0][n_offset + 1] = tmp_01; + this->elt_[mi * ROWS_PER_THREAD + 1][n_offset + 0] = tmp_10; + this->elt_[mi * ROWS_PER_THREAD + 1][n_offset + 1] = tmp_11; + } // ii + } // ni + } // mi + } + + inline __device__ void unpack_noscale(Accumulator const (&acc)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < COLS_PER_THREAD; ++ii) { + float tmp_00 = + acc[mi][ni].elt(ii * ROWS_PER_THREAD * ELTS_PER_THREAD + 0 * ELTS_PER_THREAD + 0); + float tmp_01 = + acc[mi][ni].elt(ii * ROWS_PER_THREAD * ELTS_PER_THREAD + 0 * ELTS_PER_THREAD + 1); + float tmp_10 = + acc[mi][ni].elt(ii * ROWS_PER_THREAD * ELTS_PER_THREAD + 1 * ELTS_PER_THREAD + 
0); + float tmp_11 = + acc[mi][ni].elt(ii * ROWS_PER_THREAD * ELTS_PER_THREAD + 1 * ELTS_PER_THREAD + 1); + int n_offset = ni * COLS_PER_THREAD * ELTS_PER_THREAD + ii * ELTS_PER_THREAD; + this->elt_[mi * ROWS_PER_THREAD + 0][n_offset + 0] = tmp_00; + this->elt_[mi * ROWS_PER_THREAD + 0][n_offset + 1] = tmp_01; + this->elt_[mi * ROWS_PER_THREAD + 1][n_offset + 0] = tmp_10; + this->elt_[mi * ROWS_PER_THREAD + 1][n_offset + 1] = tmp_11; + } // ii + } // ni + } // mi + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax, + Cta_tile, Kernel_traits> + : public Softmax_gmma_32bit_8bit_base< + fmha::Hopper_qgmma_e4m3_fp32_traits, + Cta_tile, Kernel_traits> { + // The traits. + using Traits = fmha::Hopper_qgmma_e4m3_fp32_traits; + // The Base class. + using Base = Softmax_gmma_32bit_8bit_base; + + using Accumulator = typename Base::Accumulator; + + enum { + MMAS_M = Base::MMAS_M, + MMAS_N = Base::MMAS_N, + ROWS_PER_THREAD = Base::ROWS_PER_THREAD, + COLS_PER_THREAD = Base::COLS_PER_THREAD, + ELTS_PER_THREAD = Base::ELTS_PER_THREAD, + }; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx), params_scale_softmax_(params.scale_softmax) {} + + // Store the tile after softmax. + template + inline __device__ void store(Gmem_tile& gmem_tile) { + float const scale = reinterpret_cast(this->params_scale_softmax_); + + Accumulator acc[MMAS_M][MMAS_N]; + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < COLS_PER_THREAD; ++ii) { + int row = mi * ROWS_PER_THREAD; + int col = ni * COLS_PER_THREAD * ELTS_PER_THREAD + ii * ELTS_PER_THREAD; + float tmp_00 = this->elt_[row + 0][col + 0] * scale; + float tmp_01 = this->elt_[row + 0][col + 1] * scale; + float tmp_10 = this->elt_[row + 1][col + 0] * scale; + float tmp_11 = this->elt_[row + 1][col + 1] * scale; + + int elt_idx = ii * ROWS_PER_THREAD * ELTS_PER_THREAD; + acc[mi][ni].elt(elt_idx + 0 * ELTS_PER_THREAD + 0) = tmp_00; + acc[mi][ni].elt(elt_idx + 0 * ELTS_PER_THREAD + 1) = tmp_01; + acc[mi][ni].elt(elt_idx + 1 * ELTS_PER_THREAD + 0) = tmp_10; + acc[mi][ni].elt(elt_idx + 1 * ELTS_PER_THREAD + 1) = tmp_11; + } // ii + } // ni + } // mi + + // Delegate to the gmem tile to store. + gmem_tile.store(acc); + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { + static_assert(M == 1); + static_assert(Fragment_a::NUM_REGS == 4); + static_assert(Fragment_a::NUM_ELTS == 16); + // Acc per warp: 16 x 256 FP32 + // A is 8 times(in K) 16 x 32 FP8, i.e. 4 registers per thread. + + static_assert(MMAS_N * COLS_PER_THREAD * ELTS_PER_THREAD % 8 == 0); + static_assert(MMAS_N * COLS_PER_THREAD * ELTS_PER_THREAD == K * Fragment_a::NUM_ELTS / 2); + + float const scale = reinterpret_cast(this->params_scale_softmax_); + +// The canonical layout in K should be R0: [0,1,2,3] R2: [16,17,18,19] +// Note below that this is not possible with the register layout of the accumulator. +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 8 elements per row. 
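+        // The "// + n" annotations below record the K offset each value ends up at after
+        // packing: adjacent elt_ pairs come from different 8-column core matrices, so
+        // register 0 holds offsets {0, 1, 8, 9} rather than the canonical {0, 1, 2, 3}
+        // mentioned above.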
+ float tmp_00 = this->elt_[2 * mi + 0][8 * ki + 0] * scale; // + 0 + float tmp_01 = this->elt_[2 * mi + 0][8 * ki + 1] * scale; // + 1 + float tmp_02 = this->elt_[2 * mi + 0][8 * ki + 2] * scale; // + 8 + float tmp_03 = this->elt_[2 * mi + 0][8 * ki + 3] * scale; // + 9 + float tmp_04 = this->elt_[2 * mi + 0][8 * ki + 4] * scale; // +16 + float tmp_05 = this->elt_[2 * mi + 0][8 * ki + 5] * scale; // +17 + float tmp_06 = this->elt_[2 * mi + 0][8 * ki + 6] * scale; // +24 + float tmp_07 = this->elt_[2 * mi + 0][8 * ki + 7] * scale; // +25 + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][8 * ki + 0] * scale; // + 0 + float tmp_11 = this->elt_[2 * mi + 1][8 * ki + 1] * scale; // + 1 + float tmp_12 = this->elt_[2 * mi + 1][8 * ki + 2] * scale; // + 8 + float tmp_13 = this->elt_[2 * mi + 1][8 * ki + 3] * scale; // + 9 + float tmp_14 = this->elt_[2 * mi + 1][8 * ki + 4] * scale; // +16 + float tmp_15 = this->elt_[2 * mi + 1][8 * ki + 5] * scale; // +17 + float tmp_16 = this->elt_[2 * mi + 1][8 * ki + 6] * scale; // +24 + float tmp_17 = this->elt_[2 * mi + 1][8 * ki + 7] * scale; // +25 + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float4_to_fp8x4(tmp_00, tmp_01, tmp_02, tmp_03); + dst[ki][mi].reg(1) = fmha::float4_to_fp8x4(tmp_10, tmp_11, tmp_12, tmp_13); + dst[ki][mi].reg(2) = fmha::float4_to_fp8x4(tmp_04, tmp_05, tmp_06, tmp_07); + dst[ki][mi].reg(3) = fmha::float4_to_fp8x4(tmp_14, tmp_15, tmp_16, tmp_17); + } + } + } + + uint32_t const params_scale_softmax_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax, + Cta_tile, Kernel_traits> + : public Softmax_gmma_32bit_8bit_base< + fmha::Hopper_igmma_int8_int32_traits, + Cta_tile, Kernel_traits> { + // The traits. + using Traits = fmha::Hopper_igmma_int8_int32_traits; + + // The Base class. + using Base = Softmax_gmma_32bit_8bit_base; + + using Accumulator = typename Base::Accumulator; + + enum { + MMAS_M = Base::MMAS_M, + MMAS_N = Base::MMAS_N, + ROWS_PER_THREAD = Base::ROWS_PER_THREAD, + COLS_PER_THREAD = Base::COLS_PER_THREAD, + ELTS_PER_THREAD = Base::ELTS_PER_THREAD, + }; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx), params_scale_softmax_(params.scale_softmax) {} + + // Store the tile after softmax. + template + inline __device__ void store(Gmem_tile& gmem_tile) { + float const scale = reinterpret_cast(this->params_scale_softmax_); + Accumulator acc[MMAS_M][MMAS_N]; + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < COLS_PER_THREAD; ++ii) { + int n_offset = ni * COLS_PER_THREAD * ELTS_PER_THREAD + ii * ELTS_PER_THREAD; + float tmp_00 = this->elt_[mi * ROWS_PER_THREAD + 0][n_offset + 0]; + float tmp_01 = this->elt_[mi * ROWS_PER_THREAD + 0][n_offset + 1]; + float tmp_10 = this->elt_[mi * ROWS_PER_THREAD + 1][n_offset + 0]; + float tmp_11 = this->elt_[mi * ROWS_PER_THREAD + 1][n_offset + 1]; + + int elt_offset = ii * ROWS_PER_THREAD * ELTS_PER_THREAD; + acc[mi][ni].elt(elt_offset + 0 * ELTS_PER_THREAD + 0) = tmp_00 * scale; + acc[mi][ni].elt(elt_offset + 0 * ELTS_PER_THREAD + 1) = tmp_01 * scale; + acc[mi][ni].elt(elt_offset + 1 * ELTS_PER_THREAD + 0) = tmp_10 * scale; + acc[mi][ni].elt(elt_offset + 1 * ELTS_PER_THREAD + 1) = tmp_11 * scale; + } // ii + } // ni + } // mi + + // Delegate to the gmem tile to store. 
+ gmem_tile.store(acc); + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { + static_assert(M == 1); + static_assert(Fragment_a::NUM_REGS == 4); + static_assert(Fragment_a::NUM_ELTS == 16); + // Acc per warp: 16 x 256 FP32 + // A is 8 times(in K) 16 x 32 FP8, i.e. 4 registers per thread. + + static_assert(MMAS_N * COLS_PER_THREAD * ELTS_PER_THREAD % 8 == 0); + static_assert(MMAS_N * COLS_PER_THREAD * ELTS_PER_THREAD == K * Fragment_a::NUM_ELTS / 2); + + float const scale = reinterpret_cast(this->params_scale_softmax_); +// The canonical layout in K should be R0: [0,1,2,3] R2: [16,17,18,19] +// Note below that this is not possible with the register layout of the accumulator. +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 8 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][8 * ki + 0] * scale; // + 0 + float tmp_01 = this->elt_[2 * mi + 0][8 * ki + 1] * scale; // + 1 + float tmp_02 = this->elt_[2 * mi + 0][8 * ki + 2] * scale; // + 8 + float tmp_03 = this->elt_[2 * mi + 0][8 * ki + 3] * scale; // + 9 + float tmp_04 = this->elt_[2 * mi + 0][8 * ki + 4] * scale; // +16 + float tmp_05 = this->elt_[2 * mi + 0][8 * ki + 5] * scale; // +17 + float tmp_06 = this->elt_[2 * mi + 0][8 * ki + 6] * scale; // +24 + float tmp_07 = this->elt_[2 * mi + 0][8 * ki + 7] * scale; // +25 + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][8 * ki + 0] * scale; // + 0 + float tmp_11 = this->elt_[2 * mi + 1][8 * ki + 1] * scale; // + 1 + float tmp_12 = this->elt_[2 * mi + 1][8 * ki + 2] * scale; // + 8 + float tmp_13 = this->elt_[2 * mi + 1][8 * ki + 3] * scale; // + 9 + float tmp_14 = this->elt_[2 * mi + 1][8 * ki + 4] * scale; // +16 + float tmp_15 = this->elt_[2 * mi + 1][8 * ki + 5] * scale; // +17 + float tmp_16 = this->elt_[2 * mi + 1][8 * ki + 6] * scale; // +24 + float tmp_17 = this->elt_[2 * mi + 1][8 * ki + 7] * scale; // +25 + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float4_to_char4(tmp_00, tmp_01, tmp_02, tmp_03); + dst[ki][mi].reg(1) = fmha::float4_to_char4(tmp_10, tmp_11, tmp_12, tmp_13); + dst[ki][mi].reg(2) = fmha::float4_to_char4(tmp_04, tmp_05, tmp_06, tmp_07); + dst[ki][mi].reg(3) = fmha::float4_to_char4(tmp_14, tmp_15, tmp_16, tmp_17); + } + } + } + + uint32_t const params_scale_softmax_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// The softmax normalization statistics used by flash attention (l, m) +template +struct Softmax_statistics { + // The shape of the MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in the M dimension. + enum { MMAS_M = Mma_tile::MMAS_M }; + + // Ctor. + template + inline __device__ Softmax_statistics(Params const& params, void const* ptr, Binfo const& binfo, + int tidx) + : ptr_(reinterpret_cast(ptr)), seqlen_(binfo.actual_seqlen) { + // The decomposition of the thread index into warp/lane. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // The position of the the warp in the CTA. + int warp_m = warp % Cta_tile::WARPS_M; + + // The position of the thread + token_ = warp_m * Mma_tile::M_PER_MMA + lane / 4; + + // Compute the offset to the first token of the sequence. + int64_t offset = binfo.bidb * params.h + binfo.bidh; + // Move the pointer to the correct position. 
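+    // offset = bidb * h + bidh selects the (batch, head) slice and lse_stride_in_bytes is
+    // the distance between two such slices, so ptr_ ends up at this sequence's statistics.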
+ ptr_ += offset * params.lse_stride_in_bytes; + } + + // Load the bias into registers (and expand). + inline __device__ void load(int step) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The index of the token. + int token = token_; + // At each iteration we jump over STEPQ elements. + token += step * Cta_tile::M; + // The extra offset inside the CTA. + token += mi * Mma_tile::M_PER_MMA_PER_CTA + (ii & 0x1) * 8; + + // Fetch the value if the token is valid. + float val = 0.0f; + if (token < seqlen_) { + val = reinterpret_cast(ptr_)[token]; + } + lm_[2 * mi + ii] = val; + } + } + } + + // The pointer to the bias. + int8_t const* ptr_; + // The length of the sequence. + int const seqlen_; + // The token that this thread is loading. + int token_; + // The bias after expansion. + float lm_[MMAS_M * 2]; +}; + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/traits.h b/csrc/fmha_v2/fmha/traits.h new file mode 100644 index 0000000000..bb6f4b700d --- /dev/null +++ b/csrc/fmha_v2/fmha/traits.h @@ -0,0 +1,942 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include + +#include "fmha/numeric_types.h" + +#define FMHA_DIV_UP(m, n) (((m) + (n) - 1) / (n)) + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Trait class for heuristically determining the tile sizes +template +struct Traits_tile_size; + +template +struct Traits_tile_size { + enum { + CTA_P_TILE_M = STEP, + CTA_P_TILE_N = S, + CTA_P_TILE_K = D, + CTA_O_TILE_M = CTA_P_TILE_M, + CTA_O_TILE_N = DV, + CTA_O_TILE_K = S + }; +}; + +template +struct Traits_tile_size { + enum { + CTA_P_TILE_M = STEP, + CTA_P_TILE_N = S, + // D =16: CTA_P_TILE_K=16 + // D =32: CTA_P_TILE_K=32 + // D>=64: CTA_P_TILE_K=64 + CTA_P_TILE_K = D < 32 ? 16 : (D < 64 ? 32 : 64), + CTA_O_TILE_M = CTA_P_TILE_M, + // D =512: CTA_TILE_N=256 + // D<=256: CTA_TILE_N=D + CTA_O_TILE_N = DV > 256 ? 256 : DV, + // D =512: CTA_O_TILE_K=16 + // D =256: CTA_O_TILE_K=32 + // D<=128: CTA_O_TILE_K=64 + CTA_O_TILE_K = std::max(K_PER_MMA, DV > 256 ? 16 : (DV > 128 ? 32 : 64)) + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The GPU architecture. + typename Gpu_arch, + // The number of rows in the CTA tile. + int M_, + // The number of cols in the CTA tile. + int N_, + // The number of elements in the the K dimension of the GEMM loop. + int K_, + // The number of valid cols in the CTA tile. + int VALID_N_, + // The number of valid elements in the the K dimension of the GEMM loop. + int VALID_K_, + // The number of rows of warps. + int WARPS_M_, + // The number of cols of warps. + int WARPS_N_, + // The number of warps in the K dimension of the GEMM loop. + int WARPS_K_> +struct Cta_tile_ { + enum { M = M_, N = N_, K = K_, VALID_N = VALID_N_, VALID_K = VALID_K_ }; + + // The number of warps. 
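+  // The warp grid is WARPS_M x WARPS_N x WARPS_K; e.g. a 4x1x1 grid yields
+  // WARPS_PER_CTA = 4 and, with 32 threads per warp, THREADS_PER_CTA = 128.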
+ enum { WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_ }; + + // The number of warps per CTA. + enum { WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K }; + + // The number of threads per warp. + enum { THREADS_PER_WARP = Gpu_arch::THREADS_PER_WARP }; + + // The number of threads per CTA. + enum { THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The GPU architecture. + typename Gpu_arch_, + // The type of the elements of A. + typename A_type_, + // The type of the elements of B. + typename B_type_, + // The type of the elements of C. + typename C_type_, + // The type of the elements of the accumulators. + typename Accumulator_type_, + // The type of the elements of the epilogue. + typename Epilogue_type_> +struct Traits { + // The architecture. + using Gpu_arch = Gpu_arch_; + // The data type for A elements. + using A_type = A_type_; + // The data type for B elements. + using B_type = B_type_; + // The data type for C elements. + using C_type = C_type_; + // The data type for accumulators. + using Accumulator_type = Accumulator_type_; + // The data type of the math in the epilogue. + using Epilogue_type = Epilogue_type_; + + // Create the description of the CTA tile from a configuration. + template + using Cta_tile_extd = Cta_tile_; + + // The number of bits per element of A. + enum { BITS_PER_ELEMENT_A = sizeof(A_type) * 8 }; + + // An offset in bytes for A. + static inline __host__ __device__ int64_t offset_in_bytes_a(int64_t offset) { + return offset * static_cast(sizeof(A_type)); + } + + // The number of bits per element of B. + enum { BITS_PER_ELEMENT_B = sizeof(B_type) * 8 }; + + // An offset in bytes for B. + static inline __host__ __device__ int64_t offset_in_bytes_b(int64_t offset) { + return offset * static_cast(sizeof(B_type)); + } + + // The number of bits per element of C. + enum { BITS_PER_ELEMENT_C = sizeof(C_type) * 8 }; + + // An offset in bytes for C. + static inline __host__ __device__ int64_t offset_in_bytes_c(int64_t offset) { + return offset * static_cast(sizeof(C_type)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Gpu_arch_base { + // By default, architectures have 32 threads per warp. + enum { THREADS_PER_WARP = 32 }; + + // By default, architectures do not support LDGSTS. + enum { HAS_LDGSTS = 0 }; + + // By default, architecture do not support super HMMA + enum { HAS_SUPER_HMMA = 0 }; + + // By default, architecture do not support TMA + enum { HAS_TMA = 0 }; + + // By default, architecture do not support GMMA + enum { HAS_GMMA = 0 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +using Cta_tile_with_k_with_padding = typename Traits_::template Cta_tile_extd< + Cta_tile_::M, Cta_tile_::N, Next_power_of_two::VALUE, Cta_tile_::N, + Next_power_of_two::VALUE, Cta_tile_::WARPS_M, Cta_tile_::WARPS_N, + Cta_tile_::WARPS_K>; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Volta : public Gpu_arch_base {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Volta_mma_tile { + // The number of elements computed with a single warp-MMA. 
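+  // (M_PER_MMA is fixed at 16 rows per warp; the N and K extents come from the
+  // instruction shape selected by the concrete Volta traits below.)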
+ enum { M_PER_MMA = 16, N_PER_MMA = N_PER_MMA_, K_PER_MMA = K_PER_MMA_ }; + + // The number of elements computed with a single CTA-MMA. + enum { + M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M, + N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N, + K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K + }; + + // The number of MMAs needed to compute the GEMM. + enum { + MMAS_M = (Cta_tile::M + M_PER_MMA_PER_CTA - 1) / M_PER_MMA_PER_CTA, + MMAS_N = (Cta_tile::N + N_PER_MMA_PER_CTA - 1) / N_PER_MMA_PER_CTA, + MMAS_K = (Cta_tile::K + K_PER_MMA_PER_CTA - 1) / K_PER_MMA_PER_CTA + }; + + // The number of valid MMAs (for Head Size) + enum { + // tile o + VALID_MMAS_N = Div_up::VALUE, + // tile p + VALID_MMAS_K = Div_up::VALUE, + }; + + // The number of elements computed per warp. + enum { + M_PER_WARP = MMAS_M * M_PER_MMA, + N_PER_WARP = MMAS_N * N_PER_MMA, + K_PER_WARP = MMAS_K * K_PER_MMA, + }; + + // Do we enable the fast path for LDS. + enum { ENABLE_LDS_FAST_PATH = 0 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Volta_hmma_fp16_traits + : public Traits { + // The K_PER_MMA for Volta_hmma_fp16_traits is 8. + enum { K_PER_MMA = 8 }; + + // The MMA tile. + template + using Mma_tile = Volta_mma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Volta_hmma_fp16_16x16x16_traits + : public Traits { + // The K_PER_MMA for Volta_hmma_fp16_16x16x16_traits is 16. + enum { K_PER_MMA = 16 }; + + // The MMA tile. + template + using Mma_tile = Volta_mma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Volta_imma_int8_int32_traits : public Traits { + // The K_PER_MMA for Volta_imma_int8_int32_traits is 16. + enum { K_PER_MMA = 16 }; + + // The MMA tile. + template + using Mma_tile = Volta_mma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Turing : public Gpu_arch_base {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Turing_mma_tile { + // The number of elements computed with a single warp-MMA. + enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = K_PER_MMA_ }; + + // The number of elements computed with a single CTA-MMA. + enum { + M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M, + N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N, + K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K + }; + + // The number of MMAs needed to compute the GEMM. + enum { + MMAS_M = Div_up::VALUE, + MMAS_N = Div_up::VALUE, + MMAS_K = Div_up::VALUE, + }; + + // The number of valid MMAs (for Head Size) + enum { + // tile o + VALID_MMAS_N = Div_up::VALUE, + // tile p + VALID_MMAS_K = Div_up::VALUE, + }; + + // The number of elements computed per warp. + enum { + M_PER_WARP = MMAS_M * M_PER_MMA, + N_PER_WARP = MMAS_N * N_PER_MMA, + K_PER_WARP = MMAS_K * K_PER_MMA, + }; + + // The distribution of threads in the output tile. 
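+  // Threads cover the accumulator tile as an 8 x 4 grid: lane / 4 selects the row (and the
+  // row 8 below it) while lane % 4 selects a pair of adjacent columns.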
+ enum { + THREADS_PER_MMA_M = 8, + THREADS_PER_MMA_N = 4, + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Turing_hmma_tile : public Turing_mma_tile {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Turing_hmma_fp16_traits + : public Traits { + // The K_PER_MMA for Turing_hmma_fp16_traits is 8. + enum { K_PER_MMA = 8 }; + + // The MMA tile. + template + using Mma_tile = Turing_hmma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Turing_hmma_fp32_traits : public Traits { + // The K_PER_MMA for Turing_hmma_fp32_traits is 8. + enum { K_PER_MMA = 8 }; + + // The MMA tile. + template + using Mma_tile = Turing_hmma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Turing_imma_int8_tile : public Turing_mma_tile {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Turing_imma_int8_int32_traits + : public Traits { + // The K_PER_MMA for Turing_imma_int8_int32_traits is 16. + enum { K_PER_MMA = 16 }; + + // The MMA tile. + template + using Mma_tile = Turing_imma_int8_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Ampere : public Gpu_arch_base { + // It has LDGSTS. + enum { HAS_LDGSTS = 1 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Ampere_hmma_tile : public Turing_mma_tile {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Ampere_hmma_fp16_traits + : public Traits { + // The K_PER_MMA for Ampere_hmma_fp16_traits is 16. + enum { K_PER_MMA = 16 }; + + // The MMA tile. + template + using Mma_tile = Ampere_hmma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Ampere_hmma_fp32_traits + : public Traits { + // The K_PER_MMA for Ampere_hmma_fp32_traits is 16. + enum { K_PER_MMA = 16 }; + + // The MMA tile. + template + using Mma_tile = Ampere_hmma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// used for Epilogue_type = bf16_t (similar to Ampere_hmma_fp16_traits). +struct Ampere_hmma_bf16_bf16_traits + : public Traits { + // The K_PER_MMA for Ampere_hmma_bf16_bf16_traits is 16. + enum { K_PER_MMA = 16 }; + + // The MMA tile. + template + using Mma_tile = Ampere_hmma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Ampere_hmma_bf16_traits : public Traits { + // The K_PER_MMA for Ampere_hmma_bf16_traits is 16. + enum { K_PER_MMA = 16 }; + + // The MMA tile. + template + using Mma_tile = Ampere_hmma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Ampere_imma_int8_tile : public Turing_mma_tile {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Ampere_imma_int8_int32_traits + : public Traits { + // The K_PER_MMA for Ampere_imma_int8_int32_traits is 32. + enum { K_PER_MMA = 32 }; + + // The MMA tile. 
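+  // (Each warp-level IMMA step now consumes 32 int8 values along K, twice the 16 used by
+  // the Turing IMMA traits above.)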
+ template + using Mma_tile = Ampere_imma_int8_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Ada : public Gpu_arch_base { + // It has LDGSTS. + enum { HAS_LDGSTS = 1 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// The following partial traits are mapped to Ampere_hmma_fp16_traits in fmha/kernel_traits.h. +// +// It is easier to implement setup.py this way. +struct Ada_hmma_fp16_traits {}; + +struct Ada_hmma_fp32_traits {}; + +struct Ada_imma_int8_int32_traits {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Ada_qmma_fp8_tile : public Turing_mma_tile {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Ada_qmma_e4m3_fp16_traits : public Traits { + // The K_PER_MMA for Ada_qmma_e4m3_fp16_traits is 32. + enum { K_PER_MMA = 32 }; + + // The MMA tile. + template + using Mma_tile = Ada_qmma_fp8_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Ada_qmma_e4m3_fp32_traits : public Traits { + // The K_PER_MMA for Ada_qmma_e4m3_fp32_traits is 32. + enum { K_PER_MMA = 32 }; + + // The MMA tile. + template + using Mma_tile = Ada_qmma_fp8_tile; + + static constexpr float SOFTMAX_FP_QUANT_SCALE = Softmax_fp_quant_scale(); + static constexpr float SOFTMAX_FP_DEQUANT_SCALE = 1.f / SOFTMAX_FP_QUANT_SCALE; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Hopper : public Gpu_arch_base { + // It has LDGSTS. + enum { HAS_LDGSTS = 1 }; + + // It has TMA. + enum { HAS_TMA = 1 }; + + // It has GMMA + enum { HAS_GMMA = 1 }; + + // for Hopper there are 4 warps per warpgroup. + enum { WARPS_PER_WARP_GROUP = 4 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Hopper related code. +// SHOULD we move this to a different file?? +//////////////////////////////////////////////////////////////////////////////////////////////////// +template +struct Hopper_cga_tile { + // The size of the CGA in terms of CTA + enum { CLUSTER_HEIGHT = HEIGHT_ }; + + enum { CLUSTER_WIDTH = WIDTH_ }; + + enum { CLUSTER_DEPTH = DEPTH_ }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +template // Number of warp group along K dim +struct Hopper_cta_tile { + // GPU arch. + using Gpu_arch = Gpu_arch_; + + // The size of the CTA tile. + // TODO: support D (not power of 2) + enum { M = M_, N = N_, K = K_, VALID_N = VALID_N_, VALID_K = VALID_K_ }; + + // The number of warp groups. + enum { WARP_GROUP_M = WARP_GROUP_M_, WARP_GROUP_N = WARP_GROUP_N_, WARP_GROUP_K = WARP_GROUP_K_ }; + + // The number of warps in a warp group. + enum { + WARPS_M_PER_GROUP = 4, + WARPS_N_PER_GROUP = 1, + WARPS_K_PER_GROUP = 1, + }; + + // The number of warps in a cta. + enum { + WARPS_M = WARPS_M_PER_GROUP * WARP_GROUP_M_, + WARPS_N = WARPS_N_PER_GROUP * WARP_GROUP_N_, + WARPS_K = WARPS_K_PER_GROUP * WARP_GROUP_K_ + }; + + // The number of warps per CTA. + enum { + WARPS_PER_CTA = WARP_GROUP_M * WARP_GROUP_N * WARP_GROUP_K * Gpu_arch::WARPS_PER_WARP_GROUP + }; + + // The number of warps per warpgroup. + enum { WARPS_PER_WARP_GROUP = Gpu_arch::WARPS_PER_WARP_GROUP }; + + // The number of threads per warp. 
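+  // (32 on all supported architectures, so a warp group is 4 x 32 = 128 threads; see
+  // THREADS_PER_WARP_GROUP below.)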
+ enum { THREADS_PER_WARP = Gpu_arch::THREADS_PER_WARP }; + + // the number of threads per warpgroup. + enum { THREADS_PER_WARP_GROUP = THREADS_PER_WARP * WARPS_PER_WARP_GROUP }; + + // The number of threads per CTA. + enum { THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP }; + + enum { GROUPS_M = 1 }; + + enum { GROUPS_N = 1 }; + + enum { GROUPS_K = 1 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hopper_gmma_tile { + // The number of elements computed with a single warp group mma. + enum { M_PER_MMA = GMMA_M, N_PER_MMA = GMMA_N, K_PER_MMA = GMMA_K }; + + // The number of warp groups. + enum { + NUM_WARP_GROUPS = Cta_tile::WARP_GROUP_M * Cta_tile::WARP_GROUP_N * Cta_tile::WARP_GROUP_K + }; + + // The number of elements computed with a single CTA-MMA. + enum { + M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARP_GROUP_M, + N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARP_GROUP_N, + K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARP_GROUP_K + }; + + // The number of MMAs needed to compute the GEMM. + enum { + MMAS_M = (Cta_tile::M + M_PER_MMA_PER_CTA - 1) / M_PER_MMA_PER_CTA, + MMAS_N = (Cta_tile::N + N_PER_MMA_PER_CTA - 1) / N_PER_MMA_PER_CTA, + MMAS_K = (Cta_tile::K + K_PER_MMA_PER_CTA - 1) / K_PER_MMA_PER_CTA, + }; + + // The number of valid MMAs (for Head Size) + enum { + // tile o + VALID_MMAS_N = Div_up::VALUE, + // tile p + VALID_MMAS_K = Div_up::VALUE, + }; + + // The number of elements computed per warp group. + enum { + M_PER_WARP_GROUP = MMAS_M * M_PER_MMA, + N_PER_WARP_GROUP = MMAS_N * N_PER_MMA, + K_PER_WARP_GROUP = MMAS_K * K_PER_MMA, + }; + + // the size of GMMA group, which is GMMA_M x GMMA_N x Kblock. + enum { + M_PER_GMMA_GROUP = GMMA_M, + N_PER_GMMA_GROUP = GMMA_N, + K_PER_GMMA_GROUP = Cta_tile::K, + }; + + // The distribution of threads in the output tile. + // TODO + enum { + THREADS_PER_MMA_M = 8, + THREADS_PER_MMA_N = 4, + }; + + // The number of core matrices per GMMA. + enum { + CORES_M_PER_GROUP = 8 * Cta_tile::WARPS_M_PER_GROUP, + CORES_N_PER_GROUP = 8 * Cta_tile::WARPS_N_PER_GROUP, + CORES_M = GMMA_M / CORES_M_PER_GROUP, + CORES_N = GMMA_N / CORES_N_PER_GROUP, + }; + + // The number of logical rows/cols per thread. + enum { + // A thread owns 1 row per core matrix. + ROWS_PER_THREAD = CORES_M, + // A thread owns 2 col per core matrix. + COLS_PER_THREAD = CORES_N * 2, + }; + + static_assert(ROWS_PER_THREAD == 2); + static_assert(COLS_PER_THREAD == GMMA_N / 4); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +enum class Hopper_instructions { + HGMMA_FP16, + HGMMA_BF16, + HGMMA_FP32, + IGMMA_INT32, + QGMMA_E4M3_FP32, + QGMMA_E5M2_FP32, + QGMMA_E4M3_FP16, + QGMMA_E5M2_FP16 +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Hopper HGMMA FP16 Traits +template +struct Hopper_hgmma_fp16_traits + : public Traits { + // The GMMA shape. + enum { GMMA_M = GMMA_M_, GMMA_N = GMMA_N_, GMMA_K = 16 }; + + // is A operand in RF for GMMA? + static constexpr bool GMMA_A_RF = GMMA_A_RF_; + + // is B operand in RF for GMMA? + static constexpr bool GMMA_B_RF = GMMA_B_RF_; + + // GMMA shape has certain requirements. 
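+  // (M is fixed at 64 and K at 16 for 16-bit inputs, while N must be a multiple of 8 no
+  // larger than 256: a 64x128x16 shape is accepted, 64x20x16 is rejected by the asserts
+  // below.)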
+ static_assert(GMMA_K == 16, "GMMA K must be 16; this might change"); + static_assert(GMMA_M == 64, "GMMA M must be 64; this might change"); + static_assert(GMMA_N % 8 == 0, "GMMA N must be multiple of 8; this might change"); + static_assert(GMMA_N <= 256, "GMMA N must be no larger than 256; this might change"); + + // GMMA does not allow both operands coming from RF. + static_assert((GMMA_A_RF && GMMA_B_RF) != true, + "GMMA does not allow both operands coming from RF."); + + // The Cta tile. + template + using Cta_tile = Hopper_cta_tile; + + // The Cta tile. + template + using Cta_padded_tile = + Hopper_cta_tile; + + // The CGA Tile + template + using Cga_tile = Hopper_cga_tile; + + // The MMA tile. + template + using Mma_tile = Hopper_gmma_tile; + + // The handle to differentiate instructions. + static constexpr fmha::Hopper_instructions HOPPER_INSTRUCTION = + fmha::Hopper_instructions::HGMMA_FP16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Hopper HGMMA FP32 Traits +template +struct Hopper_hgmma_fp32_traits + : public Traits { + // The GMMA shape. + enum { GMMA_M = GMMA_M_, GMMA_N = GMMA_N_, GMMA_K = 16 }; + + // is A operand in RF for GMMA? + static constexpr bool GMMA_A_RF = GMMA_A_RF_; + + // is B operand in RF for GMMA? + static constexpr bool GMMA_B_RF = GMMA_B_RF_; + + // GMMA shape has certain requirements. + static_assert(GMMA_K == 16, "GMMA K must be 16; this might change"); + static_assert(GMMA_M == 64, "GMMA M must be 64; this might change"); + static_assert(GMMA_N % 8 == 0, "GMMA N must be multiple of 8; this might change"); + static_assert(GMMA_N <= 256, "GMMA N must be no larger than 256; this might change"); + + // GMMA does not allow both operands coming from RF. + static_assert((GMMA_A_RF && GMMA_B_RF) != true, + "GMMA does not allow both operands coming from RF."); + + // The Cta tile. + template + using Cta_tile = Hopper_cta_tile; + + // The Cta tile. + template + using Cta_padded_tile = + Hopper_cta_tile; + + // The CGA Tile + template + using Cga_tile = Hopper_cga_tile; + + // The MMA tile. + template + using Mma_tile = Hopper_gmma_tile; + + // The handle to differentiate instructions. + static constexpr fmha::Hopper_instructions HOPPER_INSTRUCTION = + fmha::Hopper_instructions::HGMMA_FP32; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Hopper BF16 HGMMA Traits +template +struct Hopper_hgmma_bf16_traits : public Traits { + // The GMMA shape. + enum { GMMA_M = GMMA_M_, GMMA_N = GMMA_N_, GMMA_K = 16 }; + + // is A operand in RF for GMMA? + static constexpr bool GMMA_A_RF = GMMA_A_RF_; + + // is B operand in RF for GMMA? + static constexpr bool GMMA_B_RF = GMMA_B_RF_; + + // GMMA shape has certain requirements. + static_assert(GMMA_K == 16, "GMMA K must be 16; this might change"); + static_assert(GMMA_M == 64, "GMMA M must be 64; this might change"); + static_assert(GMMA_N % 8 == 0, "GMMA N must be multiple of 8; this might change"); + static_assert(GMMA_N <= 256, "GMMA N must be no larger than 256; this might change"); + + // GMMA does not allow both operands coming from RF. + static_assert((GMMA_A_RF && GMMA_B_RF) != true, + "GMMA does not allow both operands coming from RF."); + + // The Cta tile. + template + using Cta_tile = Hopper_cta_tile; + + // The Cta tile. + template + using Cta_padded_tile = + Hopper_cta_tile; + + // The CGA Tile + template + using Cga_tile = Hopper_cga_tile; + + // The MMA tile. 
+ template + using Mma_tile = Hopper_gmma_tile; + + // The handle to differentiate instructions. + static constexpr fmha::Hopper_instructions HOPPER_INSTRUCTION = + fmha::Hopper_instructions::HGMMA_BF16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Hopper IGMMA Traits +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hopper_igmma_int8_int32_traits + : public Traits { + using Base = Traits; + + // The GMMA shape + enum { GMMA_M = GMMA_M_ }; + + enum { GMMA_N = GMMA_N_ }; + + enum { GMMA_K = 32 }; + + // is A operand in RF for GMMA? + static constexpr bool GMMA_A_RF = GMMA_A_RF_; + + // is B operand in RF for GMMA? + static constexpr bool GMMA_B_RF = GMMA_B_RF_; + + // GMMA shape has certain requirement + static_assert(GMMA_K == 32, "GMMA K must be 32; this might change"); + static_assert(GMMA_M == 64, "GMMA M must be 64; this might change"); + static_assert(GMMA_N % 8 == 0, "GMMA N must be multiple of 8; this might change"); + static_assert(GMMA_N <= 256, "GMMA N must be no larger than 256; this might change"); + + // GMMA does not allow both operands coming from RF. + static_assert((GMMA_A_RF && GMMA_B_RF) != true, + "GMMA does not allow both operands coming from RF."); + + // The Cta tile. + template + using Cta_tile = Hopper_cta_tile; + + // The Cta tile. + template + using Cta_padded_tile = + Hopper_cta_tile; + + // The CGA Tile + template + using Cga_tile = Hopper_cga_tile; + + // The MMA tile. + template + using Mma_tile = Hopper_gmma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Hopper QGMMA Traits +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hopper_qgmma_fp8_fp32_traits + : public Traits { + using Base = Traits; + + using Input_type_A = Input_type_A_; + using Input_type_B = Input_type_B_; + using Output_type = Output_type_; + + // The GMMA shape + enum { GMMA_M = GMMA_M_ }; + + enum { GMMA_N = GMMA_N_ }; + + enum { GMMA_K = 32 }; + + // is A operand in RF for GMMA? + static constexpr bool GMMA_A_RF = GMMA_A_RF_; + + // is B operand in RF for GMMA? + static constexpr bool GMMA_B_RF = GMMA_B_RF_; + + // GMMA shape has certain requirement + static_assert(GMMA_K == 32, "GMMA K must be 32; this might change"); + static_assert(GMMA_M == 64, "GMMA M must be 64; this might change"); + static_assert(GMMA_N % 8 == 0, "GMMA N must be multiple of 8; this might change"); + static_assert(GMMA_N <= 256, "GMMA N must be no larger than 256; this might change"); + + // GMMA does not allow both operands coming from RF. + static_assert((GMMA_A_RF && GMMA_B_RF) != true, + "GMMA does not allow both operands coming from RF."); + + // The Cta tile. + template + using Cta_tile = Hopper_cta_tile; + + // The Cta tile. + template + using Cta_padded_tile = + Hopper_cta_tile; + + // The CGA Tile + template + using Cga_tile = Hopper_cga_tile; + + // The XMMA tile. + template + using Mma_tile = Hopper_gmma_tile; + + // Used by low precision floating point types (e4m3, e5m2, etc.) 
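+  // SOFTMAX_FP_DEQUANT_SCALE below is defined as the exact reciprocal of
+  // SOFTMAX_FP_QUANT_SCALE, so whatever scaling is applied to the FP8 softmax output can be
+  // undone exactly afterwards.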
+ static constexpr float SOFTMAX_FP_QUANT_SCALE = Softmax_fp_quant_scale(); + static constexpr float SOFTMAX_FP_DEQUANT_SCALE = 1.f / SOFTMAX_FP_QUANT_SCALE; +}; + +template +using Hopper_qgmma_e4m3_fp32_traits = + Hopper_qgmma_fp8_fp32_traits; + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/utils.h b/csrc/fmha_v2/fmha/utils.h new file mode 100644 index 0000000000..f65d2fe661 --- /dev/null +++ b/csrc/fmha_v2/fmha/utils.h @@ -0,0 +1,2355 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include +#include +#include + +#if defined(__CLANGD__) +#include <__clang_cuda_builtin_vars.h> +#include <__clang_cuda_math.h> +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +// include warpgroup related instructions, used by SM90. +#include +// include gmma related instructions, used by SM90. +#include +// include tma related instructions, used by SM90. +#include + +#include "fmha/numeric_types.h" + +#define FP32_I2F_MAGIC_NUMBER 12582912.f +#define FP32_I2F_MAGIC_NUMBER_HEX 0x4b400000 + +extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void* ptr); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace introspection { + +template +struct Unpack; + +template +struct Unpack { + // if we simply static_assert(false) then compiler will not emit template params upon failure + static_assert(N < INT_MIN, ""); + using Type = std::integral_constant; +}; + +template +struct Unpack { + using Type = Unpack; + using Unpack_first = typename Unpack::Type; + using Unpack_remaining = typename Unpack::Type; +}; + +} // namespace introspection + +// Example usage: +// +// Inspect_ns<(int)USE_LDGSTS_, PRED_REGS, (int)IS_HOPPER> foo; +// +// or +// +// Inspect_ns<(int)USE_LDGSTS_, PRED_REGS, (int)IS_HOPPER>{}.foo(); +// +// Output by nvcc: +// +// ./src/fmha/gmem_tile_qkv_packed.h(70): error: static assertion failed with "" +// detected during: +// instantiation of class "fmha::v2::Unpack [with N=1]" +// (77): here +// instantiation of class "fmha::v2::Unpack [with N=1, Ns=<2, 0>]" +// (84): here +// instantiation of class "fmha::v2::Inspect_ns [with Ns=<1, 2, 0>]" +// (143): here +template +struct Inspect_ns { + using Type = typename introspection::Unpack::Type; +}; + +// Can be used alongside with static_assert() to figure out the conditions when assertion failed +// Example: +// +// Cond_inspect_ns< (int)ROWS >= (int)ROWS_PER_LDG, ROWS, ROWS_PER_LDG> foo; +// +// Output by nvcc (when condition is not met): +// +// ./src/fmha/utils.h(163): error: static assertion failed with "" +// detected during: +// instantiation of class "Cond_inspect_ns [with COND=false, Ns=<32, +// 64>]" +template +struct Cond_inspect_ns { + static_assert(COND, ""); +}; + +// Example: +// +// Inspect_type{}.foo(); +// +// or +// +// Inspect_type foo; +// +// Output by nvcc: +// +// ./src/fmha/utils.h(189): error: class "fmha::Ampere_hmma_tile, 16>" has no member "Dummy" +// 
detected during: +// instantiation of class "Inspect_type [with +// T=fmha::Ampere_hmma_tile, 16>]" +template +struct Inspect_type { + // Purposefully trigger error by referencing non-existent T::Dummy + using Dummy = typename T::Dummy; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Row { + static constexpr bool COL = false; + static constexpr bool ROW = true; +}; + +struct Col { + static constexpr bool COL = true; + static constexpr bool ROW = false; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Round_up { + enum { VALUE = (M + N - 1) / N * N }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Tile_nhw { + enum { N = N_, H = H_, W = W_ }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Next_power_of_two {}; + +template +struct Next_power_of_two { + enum { VALUE = M }; +}; + +template <> +struct Next_power_of_two<3, false> { + enum { VALUE = 4 }; +}; + +template <> +struct Next_power_of_two<5, false> { + enum { VALUE = 8 }; +}; + +template <> +struct Next_power_of_two<6, false> { + enum { VALUE = 8 }; +}; + +template <> +struct Next_power_of_two<7, false> { + enum { VALUE = 8 }; +}; + +template <> +struct Next_power_of_two<9, false> { + enum { VALUE = 16 }; +}; + +template <> +struct Next_power_of_two<10, false> { + enum { VALUE = 16 }; +}; + +template <> +struct Next_power_of_two<11, false> { + enum { VALUE = 16 }; +}; + +template <> +struct Next_power_of_two<12, false> { + enum { VALUE = 16 }; +}; + +template <> +struct Next_power_of_two<13, false> { + enum { VALUE = 16 }; +}; + +template <> +struct Next_power_of_two<14, false> { + enum { VALUE = 16 }; +}; + +template <> +struct Next_power_of_two<15, false> { + enum { VALUE = 16 }; +}; + +template <> +struct Next_power_of_two<24, false> { + enum { VALUE = 32 }; +}; + +template <> +struct Next_power_of_two<40, false> { + enum { VALUE = 64 }; +}; + +template <> +struct Next_power_of_two<48, false> { + enum { VALUE = 64 }; +}; + +template <> +struct Next_power_of_two<72, false> { + enum { VALUE = 128 }; +}; + +template <> +struct Next_power_of_two<80, false> { + enum { VALUE = 128 }; +}; + +template <> +struct Next_power_of_two<96, false> { + enum { VALUE = 128 }; +}; + +template <> +struct Next_power_of_two<104, false> { + enum { VALUE = 128 }; +}; + +template <> +struct Next_power_of_two<112, false> { + enum { VALUE = 128 }; +}; + +template <> +struct Next_power_of_two<144, false> { + enum { VALUE = 256 }; +}; + +template <> +struct Next_power_of_two<160, false> { + enum { VALUE = 256 }; +}; + +template <> +struct Next_power_of_two<192, false> { + enum { VALUE = 256 }; +}; + +template <> +struct Next_power_of_two<576, false> { + enum { VALUE = 1024 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Prev_power_of_two {}; + +template +struct Prev_power_of_two { + enum { VALUE = N }; +}; + +template <> +struct Prev_power_of_two<3, false> { + enum { VALUE = 2 }; +}; + +template <> +struct Prev_power_of_two<5, false> { + enum { VALUE = 4 }; +}; + +template <> +struct Prev_power_of_two<6, false> { + enum { VALUE = 4 }; +}; + +template <> 
+struct Prev_power_of_two<7, false> { + enum { VALUE = 4 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Compute_skew { + // The size of a transaction. + enum { BYTES_PER_TRX = 128 }; + + // The remainder of the row without skew. + enum { REMAINDER = BYTES_PER_ROW % BYTES_PER_TRX }; + + // The value. + enum { VALUE = REMAINDER <= SKEW ? SKEW - REMAINDER : BYTES_PER_TRX + SKEW - REMAINDER }; + + // Make sure the math works ;) + static_assert((BYTES_PER_ROW + VALUE) % BYTES_PER_TRX == SKEW, ""); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Compute_skew { + // No skew! + enum { VALUE = 0 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Div_up { + enum { VALUE = (M + N - 1) / N }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Max { + enum { VALUE = A >= B ? A : B }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Max_3 { + enum { VALUE = Max::VALUE, C>::VALUE }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Min { + enum { VALUE = A <= B ? A : B }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Uint_from_size_in_bytes {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Uint_from_size_in_bytes<1> { + using Type = uint8_t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Uint_from_size_in_bytes<2> { + using Type = uint16_t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Uint_from_size_in_bytes<4> { + using Type = uint32_t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Uint_from_size_in_bytes<8> { + using Type = uint2; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Uint_from_size_in_bytes<16> { + using Type = uint4; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Warp_masks {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Warp_masks<8, 1, 1> { + enum { M = 0xe0, N = 0x00, K = 0x00 }; +}; + +template <> +struct Warp_masks<4, 2, 1> { + enum { M = 0x60, N = 0x80, K = 0x00 }; +}; + +template <> +struct Warp_masks<4, 1, 2> { + enum { M = 0x60, N = 0x00, K = 0x80 }; +}; + +template <> +struct Warp_masks<4, 1, 1> { + enum { M = 0x60, N = 0x00, K = 0x00 }; +}; + +template <> +struct Warp_masks<2, 4, 1> { + enum { M = 0x20, N = 0xc0, K = 0x00 }; +}; + +template <> +struct Warp_masks<2, 2, 2> { + enum { M = 0x20, N = 0x40, K = 0x80 }; +}; + +template <> +struct Warp_masks<2, 2, 1> { + enum { M = 0x20, N = 0x40, K = 0x00 }; +}; + +template <> +struct Warp_masks<2, 1, 2> { + enum { M = 0x20, N = 0x00, K = 0x40 }; +}; + +template <> +struct Warp_masks<2, 1, 1> { + enum { M = 0x20, N 
= 0x00, K = 0x00 }; +}; + +template <> +struct Warp_masks<1, 8, 1> { + enum { M = 0x00, N = 0xe0, K = 0x00 }; +}; + +template <> +struct Warp_masks<1, 4, 2> { + enum { M = 0x00, N = 0x60, K = 0x80 }; +}; + +template <> +struct Warp_masks<1, 4, 1> { + enum { M = 0x00, N = 0x60, K = 0x00 }; +}; + +template <> +struct Warp_masks<1, 2, 2> { + enum { M = 0x00, N = 0x20, K = 0x40 }; +}; + +template <> +struct Warp_masks<1, 2, 1> { + enum { M = 0x00, N = 0x20, K = 0x00 }; +}; + +template <> +struct Warp_masks<1, 1, 4> { + enum { M = 0x00, N = 0x00, K = 0x60 }; +}; + +template <> +struct Warp_masks<1, 1, 2> { + enum { M = 0x00, N = 0x00, K = 0x20 }; +}; + +template <> +struct Warp_masks<1, 1, 1> { + enum { M = 0x00, N = 0x00, K = 0x00 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ __host__ T div_up(T m, T n) { + return (m + n - 1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline int clz(int x) { + for (int i = 31; i >= 0; --i) { + if ((1 << i) & x) { + return 31 - i; + } + } + return 32; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline int find_log_2(int x, bool round_up = false) { + int a = 31 - clz(x); + if (round_up) { + a += (x & (x - 1)) ? 1 : 0; + } + return a; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline void find_divisor(uint32_t& mul, uint32_t& shr, int x) { + assert(x != 0); + if (x == 1) { + // If dividing by 1, reduced math doesn't work because mul_coeff would need to be 2^32, + // which doesn't fit into unsigned int. the div() routine handles this special case + // separately. + mul = 0; + shr = 0; + } else { + // To express the division N/D in terms of a multiplication, what we first + // imagine is simply N*(1/D). However, 1/D will always evaluate to 0 (for D>1), + // so we need another way. There's nothing that says we have to use exactly + // the fraction 1/D; instead it could be any X/Y that reduces to 1/D (i.e., + // Y=X*D), or at least to "close enough" to it. If we pick Y that is a power + // of two, then the N*(X/Y) can be N*X followed by a right-shift by some amount. + // The power of two we should pick should be at least 2^32, because in the + // div() routine we'll use umulhi(), which returns only the upper 32 bits -- + // this being equivalent to a right-shift by 32. But we might want a higher + // power of two for better accuracy depending on the magnitude of the denominator. + // Once we've picked Y, then X [our mul_coeff value] is simply Y/D, rounding up, + // and we save shift_coeff as whatever further shift we have to do beyond + // what the umulhi() implies. 
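+    // Worked example (illustrative): for x = 3, find_log_2(3, true) = 2, so p = 33,
+    // m = ceil(2^33 / 3) = 0xAAAAAAAB and shr = 1. fast_divmod() below then computes
+    // div = __umulhi(N, 0xAAAAAAAB) >> 1, which equals N / 3 for every 32-bit N.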
+ uint32_t p = 31 + find_log_2(x, true); + uint32_t m = (uint32_t)(((1ull << p) + (uint32_t)x - 1) / (uint32_t)x); + + mul = m; + shr = p - 32; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void fast_divmod(int& div, int& mod, int x, int y, uint32_t mul, uint32_t shr) { + if (y == 1) { + div = x; + mod = 0; + } else { + div = __umulhi((uint32_t)x, mul) >> shr; + mod = x - div * y; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t bfadd2(uint32_t a, uint32_t b) { + uint32_t c; + uint32_t one = 0x3f803f80; + ; + asm volatile("fma.rn.bf16x2 %0, %1, %3, %2;\n" : "=r"(c) : "r"(a), "r"(b), "r"(one)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hmax2(uint32_t a, uint32_t b) { + uint32_t c; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("max.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b)); +#else + asm volatile( + "{\n" + "\t .reg .f16x2 sela, selb;\n" + "\n" + "\t set.ge.f16x2.f16x2 sela, %1, %2;\n" + "\t set.gt.f16x2.f16x2 selb, %2, %1;\n" + "\n" + "\t mul.f16x2 %0, sela, %1;\n" + "\t fma.rn.f16x2 %0, selb, %2, %0;\n" + "}\n" + : "=r"(c) + : "r"(a), "r"(b)); +#endif + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint2 hmax4(uint2 a, uint2 b) { + uint2 c; + c.x = hmax2(a.x, b.x); + c.y = hmax2(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 hmax8(uint4 a, uint4 b) { + uint4 c; + c.x = hmax2(a.x, b.x); + c.y = hmax2(a.y, b.y); + c.z = hmax2(a.z, b.z); + c.w = hmax2(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) { + uint32_t c; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b)); +#else + asm volatile( + "{\n" + "\t .reg .f16x2 sela, selb;\n" + "\n" + "\t set.le.f16x2.f16x2 sela, %1, %2;\n" + "\t set.lt.f16x2.f16x2 selb, %2, %1;\n" + "\n" + "\t mul.f16x2 %0, sela, %1;\n" + "\t fma.rn.f16x2 %0, selb, %2, %0;\n" + "}\n" + : "=r"(c) + : "r"(a), "r"(b)); +#endif + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t bfmul2(uint32_t a, uint32_t b) { + uint32_t c; + asm("{.reg .b32 c;\n" + " mov.b32 c, 0x80008000U;\n" + " fma.rn.bf16x2 %0,%1,%2,c;}\n" + : "=r"(c) + : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint2 
hmul4(uint2 a, uint2 b) { + uint2 c; + c.x = hmul2(a.x, b.x); + c.y = hmul2(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 hmul8(uint4 a, uint4 b) { + uint4 c; + c.x = hmul2(a.x, b.x); + c.y = hmul2(a.y, b.y); + c.z = hmul2(a.z, b.z); + c.w = hmul2(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 hmul8(uint32_t a, uint4 b) { + uint4 c; + c.x = hmul2(a, b.x); + c.y = hmul2(a, b.y); + c.z = hmul2(a, b.z); + c.w = hmul2(a, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Template function to support both half and bfloat16 +template +inline __device__ uint32_t mul2(uint32_t a, uint32_t b) { + return hmul2(a, b); +} + +template <> +inline __device__ uint32_t mul2(uint32_t a, uint32_t b) { + return bfmul2(a, b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Template function to support both half and bfloat16 +template +inline __device__ uint4 mul8(uint32_t a, uint4 b) { + uint4 c; + c.x = hmul2(a, b.x); + c.y = hmul2(a, b.y); + c.z = hmul2(a, b.z); + c.w = hmul2(a, b.w); + return c; +} + +template <> +inline __device__ uint4 mul8(uint32_t a, uint4 b) { + uint4 c; + c.x = bfmul2(a, b.x); + c.y = bfmul2(a, b.y); + c.z = bfmul2(a, b.z); + c.w = bfmul2(a, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hrelu2(uint32_t x) { + uint32_t res; + uint32_t const zero = 0u; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); +#else + asm volatile( + "{\n" + "\t .reg .f16x2 sela;\n" + "\t set.gtu.u32.f16x2 sela, %1, %2;\n" + "\t and.b32 %0, sela, %1;\n" + "}\n" + : "=r"(res) + : "r"(x), "r"(zero)); +#endif + return res; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t bfrelu2(uint32_t x) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + uint32_t res; + uint32_t const zero = 0u; + asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); + return res; +#endif + // not implemented yet + return x; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Template function to support both half and bfloat16 +template +inline __device__ uint32_t relu2(uint32_t x) { + return hrelu2(x); +} + +template <> +inline __device__ uint32_t relu2(uint32_t x) { + return bfrelu2(x); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t habs2(uint32_t x) { + uint32_t res; + asm volatile("abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x)); + return res; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// static inline __device__ uint32_t add_bias(uint32_t a, uint32_t bias, bool relu) { +// uint32_t c; +// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// if( relu ) { +// uint32_t one = 0x3c003c00u; +// asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(c) : "r"(a), "r"(one), +// "r"(bias)); +// } else { +// c = hadd2(a, bias); +// } +// #else +// c = 
hadd2(a, bias); +// if( relu ) { +// c = hrelu2(c); +// } +// #endif +// return c; +// } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// static inline __device__ uint2 add_bias(uint2 a, uint2 bias, bool relu) { +// uint2 dst; +// dst.x = add_bias(a.x, bias.x, relu); +// dst.y = add_bias(a.y, bias.y, relu); +// return dst; +// } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// static inline __device__ uint4 add_bias(uint4 a, uint4 bias, bool relu) { +// uint4 dst; +// dst.x = add_bias(a.x, bias.x, relu); +// dst.y = add_bias(a.y, bias.y, relu); +// dst.z = add_bias(a.z, bias.z, relu); +// dst.w = add_bias(a.w, bias.w, relu); +// return dst; +// } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// clamp float +inf/-inf +static inline __device__ float satfinite(float x) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 860 + // bit representation of maximum value of float + uint32_t clamp_value = 0x7f7fffffu; + asm volatile("min.xorsign.abs.f32 %0, %0, %1;" : "+f"(x) : "r"(clamp_value)); + return x; +#else + // bit representation of maximum and minimum value of float + uint32_t umax = 0x7f7fffffu; + uint32_t umin = 0xff7fffffu; + float out; + asm volatile("min.f32 %0, %1, %2;" : "=f"(out) : "f"(x), "r"(umax)); + asm volatile("max.f32 %0, %0, %1;" : "+f"(out) : "r"(umin)); + return out; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// clamp half2 +inf/-inf +static inline __device__ uint32_t satfinite_h2(uint32_t h2) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 860 + uint32_t out, clamp_value; + clamp_value = 0x7bff7bffu; + asm volatile("min.xorsign.abs.f16x2 %0, %1, %2;" : "=r"(out) : "r"(h2), "r"(clamp_value)); + return out; +#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800 + // bit representation of maximum and minimum value of half2 + uint32_t umax = 0x7bff7bffu; + uint32_t umin = 0xfbfffbffu; + uint32_t out; + asm volatile("min.f16x2 %0, %1, %2;" : "=r"(out) : "r"(h2), "r"(umax)); + asm volatile("max.f16x2 %0, %0, %1;" : "+r"(out) : "r"(umin)); + return out; +#else + // Take the absolute value of h2. It should map to |Rx| in SASS. + uint32_t p2; + asm volatile("abs.f16x2 %0, %1;" : "=r"(p2) : "r"(h2)); + + // Compute a mask for each fp16: 0xffff if +INF and 0x0000 otherwise. + uint32_t inf2 = 0x7c007c00u; + uint32_t mask; + asm volatile("set.eq.u32.f16x2 %0, %1, %2;" : "=r"(mask) : "r"(p2), "r"(inf2)); + + // Recreate the new value. 0x7bff is the max value for FP16. + p2 = (~mask & p2) | (mask & 0x7bff7bff); + + // Simply re-add the sign and we're done. + return p2 | (h2 & 0x80008000); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +static inline __device__ T clamp(T x, T lb, T ub) { + return x < lb ? lb : (x > ub ? 
ub : x); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float custom_exp2f(float x, float scale, float scaled_max) { + float d1, d2; + asm("fma.rz.ftz.f32 %0, %1, %2, %3;" : "=f"(d1) : "f"(x), "f"(scale), "f"(-scaled_max)); + asm("ex2.approx.ftz.f32 %0, %1;" : "=f"(d2) : "f"(d1)); + return d2; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t clamp_to_zero(uint16_t x) { + uint16_t mask; + asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x)); + return mask & x; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t float_to_half(float f) { + uint16_t h; + asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f)); + return h; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ bf16_t float_to_bf16(float f) { return __float2bfloat16(f); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t float2_to_half2(float a, float b) { + uint32_t c; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a)); +#else + uint16_t lo = float_to_half(a); + uint16_t hi = float_to_half(b); + asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi)); +#endif + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t float2_to_bf16_x2(float a, float b) { + uint32_t c; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a)); +#else + uint16_t* px = reinterpret_cast(&a); + uint16_t* py = reinterpret_cast(&b); + uint16_t value = px[1]; + uint16_t value2 = py[1]; + + if (px[0] == 0x8000) { + if ((value & 0x1) == 1) value++; + } else if (px[0] > 0x8000) { + value++; + } + + if (py[0] == 0x8000) { + if ((value2 & 0x1) == 1) value2++; + } else if (py[0] > 0x8000) { + value2++; + } + + uint32_t high = reinterpret_cast(value2); + c = (high << 16) | value; +#endif + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Template function to support both half and bfloat16 +template +inline __device__ uint32_t float2_to_16bit_2(float a, float b) { + return float2_to_half2(a, b); +} + +template <> +inline __device__ uint32_t float2_to_16bit_2(float a, float b) { + return float2_to_bf16_x2(a, b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t float_to_half2(float a) { return float2_to_half2(a, a); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t float2_to_half2(float2 const& f) { + return float2_to_half2(f.x, f.y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t float_to_bf16_2(float a) { return float2_to_bf16_x2(a, a); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint2 float4_to_half4(float x, float y, float z, 
float w) { + uint2 d; + d.x = float2_to_half2(x, y); + d.y = float2_to_half2(z, w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Template function to support both half and bfloat16 +template +inline __device__ uint2 float4_to_16bit_x4(float x, float y, float z, float w) { + uint2 d; + d.x = float2_to_half2(x, y); + d.y = float2_to_half2(z, w); + return d; +} + +template <> +inline __device__ uint2 float4_to_16bit_x4(float x, float y, float z, float w) { + uint2 d; + d.x = float2_to_bf16_x2(x, y); + d.y = float2_to_bf16_x2(z, w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c)); +#else + d = hrelu2(hfma2(a, b, c)); +#endif + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t h0_h0(uint32_t x) { + uint32_t y; + asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n" + : "=r"(y) + : "r"(x)); + return y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float h0_to_float(uint32_t h2) { + float f; + asm volatile( + "{\n" + ".reg .f16 lo, hi;\n" + "mov.b32 {lo, hi}, %1;\n" + "cvt.f32.f16 %0, lo;\n" + "}\n" + : "=f"(f) + : "r"(h2)); + return f; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t h1_h1(uint32_t x) { + uint32_t y; + asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n" + : "=r"(y) + : "r"(x)); + return y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) { + uint16_t d; + asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hadd(uint32_t a, uint32_t b) { return hadd2(a, b); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint2 hadd4(uint2 a, uint2 b) { + uint2 c; + c.x = hadd2(a.x, b.x); + c.y = hadd2(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint2 hadd(uint2 a, uint2 b) { return hadd4(a, b); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 hadd8(uint4 a, uint4 b) { + uint4 c; + c.x = hadd2(a.x, b.x); + c.y = hadd2(a.y, b.y); + c.z = hadd2(a.z, b.z); + c.w = hadd2(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Template 
function to support both half and bfloat16 +template +inline __device__ uint4 add8(uint4 a, uint4 b) { + return hadd8(a, b); +} + +template <> +inline __device__ uint4 add8(uint4 a, uint4 b) { + uint4 c; + c.x = bfadd2(a.x, b.x); + c.y = bfadd2(a.y, b.y); + c.z = bfadd2(a.z, b.z); + c.w = bfadd2(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 fadd4(uint4 a, uint4 b) { + float4 c; + c.x = reinterpret_cast(a.x) + reinterpret_cast(b.x); + c.y = reinterpret_cast(a.y) + reinterpret_cast(b.y); + c.z = reinterpret_cast(a.z) + reinterpret_cast(b.z); + c.w = reinterpret_cast(a.w) + reinterpret_cast(b.w); + return reinterpret_cast(c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 hadd(uint4 a, uint4 b) { return hadd8(a, b); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float half_to_float(uint16_t h) { + float f; + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); + return f; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float bf16_to_float(uint16_t h) { + float f; + asm volatile("mov.b32 %0, {0, %1};\n" : "=f"(f) : "h"(h)); + return f; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float2 half2_to_float2(uint32_t x) { + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x)); + return make_float2(half_to_float(lo), half_to_float(hi)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float2 bf16_2_to_float2(uint32_t x) { + float2 res; + asm volatile( + "{\n" + " .reg .b16 lo, hi;\n" + " mov.b32 {lo, hi}, %2;\n" + " mov.b32 %0, {0, lo};\n" + " mov.b32 %1, {0, hi};\n" + "}\n" + : "=f"(res.x), "=f"(res.y) + : "r"(x)); + return res; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Template function to support both half and bfloat16 +template +inline __device__ float2 convert_from_16bit_2(uint32_t x) { + return half2_to_float2(x); +} + +template <> +inline __device__ float2 convert_from_16bit_2(uint32_t x) { + return bf16_2_to_float2(x); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ void half2_to_float2(float& x, float& y, uint32_t h) { + float2 tmp = half2_to_float2(h); + x = tmp.x; + y = tmp.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) { + uint16_t d; + asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) { + uint16_t d; + asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Converted two half2's or bf162's into float, then take their dot product. 
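+// That is, for packed pairs a = (a0, a1) and b = (b0, b1) held as two 16-bit lanes of a
+// 32-bit register, the result is float(a0) * float(b0) + float(a1) * float(b1), with the
+// products and the sum accumulated in fp32 rather than in the 16-bit type.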
+template +inline __device__ float fma2_in_float(uint32_t const a, uint32_t const b) { + float2 af = fmha::convert_from_16bit_2(a); + float2 bf = fmha::convert_from_16bit_2(b); + return af.x * bf.x + af.y * bf.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Converted two vectors of 8 half's or bf16's into float, then take their dot product. +template +inline __device__ float fma8_in_float(uint4 const a, uint4 const b) { + float sum; + sum = fmha::fma2_in_float(a.x, b.x); + sum += fmha::fma2_in_float(a.y, b.y); + sum += fmha::fma2_in_float(a.z, b.z); + sum += fmha::fma2_in_float(a.w, b.w); + return sum; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float sigmoid(float x) { return 1.f / (1.f + expf(-x)); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void clear(uint16_t& dst) { dst = uint16_t(0); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void clear(uint32_t& dst) { dst = 0u; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void clear(uint2& dst) { dst = make_uint2(0u, 0u); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void clear(uint4& dst) { dst = make_uint4(0u, 0u, 0u, 0u); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// P R E D I C A T E P A C K I N G +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +enum { BYTES_PER_REG = 4, PREDS_PER_BYTE = 4, PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Compute_number_of_pred_regs { + enum { VALUE = Div_up::VALUE }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void pack_predicates(uint32_t (&preds)[M], uint32_t const (&p)[N]) { + // Make sure the values match. + static_assert(Compute_number_of_pred_regs::VALUE == M, ""); + + // The number of complete steps (where we use all the predicates in a byte). + enum { COMPLETE_BYTES = N / PREDS_PER_BYTE }; + + // Make sure we allocated enough predicate registers. + static_assert(Div_up::VALUE <= M, ""); + + // The remainder. + enum { REMAINDER = N - COMPLETE_BYTES * PREDS_PER_BYTE }; + + // Make sure we got the math right and the remainder is between 0 and 3. + static_assert(REMAINDER >= 0 && REMAINDER <= 3, ""); + + // The mask to extract the predicates. + enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 }; + + // Run complete steps. +#pragma unroll + for (int ii = 0; ii < M; ++ii) { + // The number of complete bytes for that register. Be careful it can be > than 4 ;) + int const COMPLETE = (N - ii * PREDS_PER_REG) / PREDS_PER_BYTE; + + // Pack the predicates in a register. + uint32_t reg = 0u; +#pragma unroll + for (int jj = 0; jj < 4; ++jj) { + // Early exit. + if (jj >= COMPLETE) { + break; + } + + // Prepare the array of predicates. 
+ bool tmp[PREDS_PER_BYTE]; +#pragma unroll + for (int kk = 0; kk < PREDS_PER_BYTE; ++kk) { + tmp[kk] = p[ii * PREDS_PER_REG + jj * PREDS_PER_BYTE + kk] != 0; + } + + // Store the predicates. +#pragma unroll + for (int kk = 0; kk < PREDS_PER_BYTE; ++kk) { + if (tmp[kk]) { + reg |= 1u << (jj * 8 + kk); + } + } + } + + // Skip the rest of the code if we do not have a remainder. + if (COMPLETE < 4 && REMAINDER > 0) { + // The mask to extract the predicates. + enum { REMAINDER_MASK = (1 << REMAINDER) - 1 }; + + // Prepare the array of predicates. + bool tmp[PREDS_PER_BYTE]; +#pragma unroll + for (int jj = 0; jj < REMAINDER; ++jj) { + tmp[jj] = p[COMPLETE_BYTES * PREDS_PER_BYTE + jj] != 0; + } + + // Store the predicates. +#pragma unroll + for (int jj = 0; jj < REMAINDER; ++jj) { + if (tmp[jj]) { + reg |= 1u << (COMPLETE * 8 + jj); + } + } + } + + // Store the predicate register. + preds[ii] = reg; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ uint32_t pack_predicates(uint32_t const (&p)[N]) { + uint32_t tmp[1]; + pack_predicates(tmp, p); + return tmp[0]; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// G E N E R I C P R E D I C A T E D L D G S T S +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldgsts_(Functor& fct, uint32_t const (&preds)[M]) { + // The number of complete bytes (where we use all the predicates in a byte). + enum { COMPLETE = N / PREDS_PER_BYTE }; + + // Make sure we did allocate enough predicates. + static_assert(Div_up::VALUE <= M, ""); + + // The remainder. + enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE }; + + // Make sure we got the math right and the remainder is between 0 and 3. + static_assert(REMAINDER >= 0 && REMAINDER <= 3, ""); + + // The mask to extract the predicates. + enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 }; + +// Clear the fetch registers. +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + fct.clear(ii); + } + + // Run complete steps. + bool p[PREDS_PER_BYTE]; +#pragma unroll + for (int ii = 0; ii < COMPLETE; ++ii) { + // The predicate. + uint32_t reg = preds[ii / BYTES_PER_REG]; + + // Extract the predicates. +#pragma unroll + for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) { + uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj); + p[jj] = (reg & mask) != 0u; + } + +// Issue the loads. +#pragma unroll + for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) { + fct.ldgsts(ii * PREDS_PER_BYTE + jj, p[jj]); + } + } + + // Skip the rest of the code if we do not have a remainder. + if (REMAINDER > 0) { + // The mask to extract the predicates. + enum { REMAINDER_MASK = (1 << REMAINDER) - 1 }; + + // The predicate register. + uint32_t reg = preds[COMPLETE / BYTES_PER_REG]; + + // Extract the predicates. +#pragma unroll + for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) { + uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj); + p[jj] = (reg & mask) != 0u; + } + +// Issue the loads. 
+#pragma unroll
+    for (int ii = 0; ii < REMAINDER; ++ii) {
+      fct.ldgsts(COMPLETE * PREDS_PER_BYTE + ii, p[ii]);
+    }
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <int N, typename Functor>
+inline __device__ void ldgsts_(Functor& fct, uint32_t preds) {
+  uint32_t tmp[1] = {preds};
+  ldgsts_<N>(fct, tmp);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// L D G
+//
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void ldg(uint8_t& dst, void const* ptr) {
+  dst = *reinterpret_cast<uint8_t const*>(ptr);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void ldg(uint16_t& dst, void const* ptr) {
+  dst = *reinterpret_cast<uint16_t const*>(ptr);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void ldg(uint32_t& dst, void const* ptr) {
+  dst = *reinterpret_cast<uint32_t const*>(ptr);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void ldg(uint2& dst, void const* ptr) {
+  dst = *reinterpret_cast<uint2 const*>(ptr);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void ldg(uint4& dst, void const* ptr) {
+  dst = *reinterpret_cast<uint4 const*>(ptr);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Data_type, int N>
+struct Ldg_functor {
+  // Ctor.
+  inline __device__ Ldg_functor(Data_type (&fetch)[N], void const* (&ptrs)[N])
+      : fetch_(fetch), ptrs_(ptrs) {}
+
+  // Clear the element.
+  inline __device__ void clear(int ii) { fmha::clear(fetch_[ii]); }
+
+  // Trigger the loads.
+  inline __device__ void ldgsts(int ii, bool p) {
+    if (p) {
+      ldg(fetch_[ii], ptrs_[ii]);
+    }
+  }
+
+  // The fetch registers.
+  Data_type (&fetch_)[N];
+  // The pointers.
+ void const* (&ptrs_)[N]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldg_(Data_type (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M]) { + Ldg_functor fct(fetch, ptrs); + ldgsts_(fct, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldg(uint8_t (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldg(uint16_t (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldg(uint32_t (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldg(uint2 (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldg(uint4 (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldgdepbar() { + if (USE_LDGSTS) { + asm volatile("cp.async.commit_group;\n" ::); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void depbar_() { + if (USE_LDGSTS) { + asm volatile("cp.async.wait_group %0;\n" ::"n"(COUNT)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void depbar() { + if (USE_LDGSTS) { + int const VALUE = Max::VALUE; + asm volatile("cp.async.wait_group %0;\n" ::"n"(VALUE)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldgsts128(uint32_t dst, void const* src, bool p = true) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + uint32_t m = p ? 16u : 0u; + asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n" ::"r"(dst), "l"(src), "r"(m)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Ldgsts_functor { + // Ctor. + inline __device__ Ldgsts_functor(uint32_t (&smem_ptrs)[N], void const* (&gmem_ptrs)[N]) + : smem_ptrs_(smem_ptrs), gmem_ptrs_(gmem_ptrs) {} + + // Does nothing. + inline __device__ void clear(int ii) {} + + // Trigger the load-store instruction. + inline __device__ void ldgsts(int ii, bool p) { ldgsts128(smem_ptrs_[ii], gmem_ptrs_[ii], p); } + + // The shared memory pointers. + uint32_t (&smem_ptrs_)[N]; + // The global memory pointers. 
+ void const* (&gmem_ptrs_)[N]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldgsts(uint32_t (&dst)[N], void const* (&src)[N], uint32_t (&preds)[M]) { + Ldgsts_functor fct(dst, src); + ldgsts_(fct, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// L D S +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void lds(uint16_t& dst, uint32_t ptr) { + asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void lds(uint32_t& dst, uint32_t ptr) { + asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void lds(uint2& dst, uint32_t ptr) { + asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void lds(uint4& dst, uint32_t ptr) { + asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) + : "r"(ptr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// L D S M +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsm(uint32_t& dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(dst) : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsmt(uint32_t& dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n" + : "=r"(dst) + : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsm(uint2& dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst.x), "=r"(dst.y) + : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsmt(uint2& dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst.x), "=r"(dst.y) + : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsm(uint4& dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) + : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsmt(uint4& dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm 
volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) + : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// S T S M +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stsm(uint32_t ptr, uint32_t const& src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("stmatrix.sync.aligned.m8n8.x1.shared.b16 [%0], {%1};\n" ::"r"(ptr), "r"(src)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stsmt(uint32_t ptr, uint32_t const& src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("stmatrix.sync.aligned.m8n8.x1.trans.shared.b16 [%0], {%1};\n" ::"r"(ptr), "r"(src)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stsm(uint32_t ptr, uint2 const& src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("stmatrix.sync.aligned.m8n8.x2.shared.b16 [%0], {%1, %2};\n" ::"r"(ptr), "r"(src.x), + "r"(src.y)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stsmt(uint32_t ptr, uint2 const& src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("stmatrix.sync.aligned.m8n8.x2.trans.shared.b16 [%0], {%1, %2};\n" ::"r"(ptr), + "r"(src.x), "r"(src.y)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stsm(uint32_t ptr, uint4 const& src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"r"(ptr), + "r"(src.x), "r"(src.y), "r"(src.z), "r"(src.w)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stsmt(uint32_t ptr, uint4 const& src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile( + "stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"r"(ptr), + "r"(src.x), "r"(src.y), "r"(src.z), "r"(src.w)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// S T G +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void* ptr, float val) { *reinterpret_cast(ptr) = val; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void* ptr, uint8_t val) { *reinterpret_cast(ptr) = val; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void* ptr, uint16_t val) { *reinterpret_cast(ptr) = val; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void* ptr, uint32_t val) { *reinterpret_cast(ptr) = val; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void* ptr, uint2 val) { *reinterpret_cast(ptr) = val; } + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void* ptr, uint4 val) { *reinterpret_cast(ptr) = val; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// S T S +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void sts(uint32_t ptr, uint16_t val) { + asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void sts(uint32_t ptr, uint32_t val) { + asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void sts(uint32_t ptr, uint2 val) { + asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" : : "r"(ptr), "r"(val.x), "r"(val.y)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void sts(uint32_t ptr, uint4 val) { + asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" + : + : "r"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void sts_(uint32_t (&ptrs)[N], Data_type const (&data)[N]) { +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + sts(ptrs[ii], data[ii]); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void sts(uint32_t (&ptrs)[N], uint16_t const (&data)[N]) { + sts_(ptrs, data); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void sts(uint32_t (&ptrs)[N], uint32_t const (&data)[N]) { + sts_(ptrs, data); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void sts(uint32_t (&ptrs)[N], uint2 const (&data)[N]) { + sts_(ptrs, data); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void sts(uint32_t (&ptrs)[N], uint4 const (&data)[N]) { + sts_(ptrs, data); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#define __HALF2_TO_UI(var) *(reinterpret_cast(&(var))) +#define __HALF2_TO_CUI(var) *(reinterpret_cast(&(var))) + +static __device__ __inline__ void atomicAdd_half2(half2* const address, const half2 val) { + asm volatile("{ red.global.add.noftz.f16x2 [%0],%1; }\n" ::"l"(address), "r"(__HALF2_TO_CUI(val)) + : "memory"); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +static inline __device__ uint32_t float4_to_char4(float x, float y, float z, float w) { +#if defined(USE_F2I_EMULATION_TRICK) + // Make sure the float is in the proper range. + float cx, cy, cz, cw; + if (CAN_BE_NEGATIVE) { + cx = fmha::clamp(x, -128.f, 127.f); + cy = fmha::clamp(y, -128.f, 127.f); + cz = fmha::clamp(z, -128.f, 127.f); + cw = fmha::clamp(w, -128.f, 127.f); + } else { + cx = fminf(x, 127.f); + cy = fminf(y, 127.f); + cz = fminf(z, 127.f); + cw = fminf(w, 127.f); + } + + // Re-add the magic number. 
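+  // (Classic float-to-int trick: assuming FP32_I2F_MAGIC_NUMBER is a constant of the form
+  // 2^23 + 2^22, adding it to a value in the clamped range forces the integer part into the
+  // low mantissa bits. For example, 5.0f + 12582912.0f = 12582917.0f, whose bit pattern is
+  // 0x4B400005 -- the low byte is the desired int8 value 0x05 -- so the prmt instructions
+  // below can pack the low byte of each of the four results without an explicit cvt.)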
+ cx += FP32_I2F_MAGIC_NUMBER; + cy += FP32_I2F_MAGIC_NUMBER; + cz += FP32_I2F_MAGIC_NUMBER; + cw += FP32_I2F_MAGIC_NUMBER; + + // We need unsigned ints... + uint32_t a = reinterpret_cast(cx); + uint32_t b = reinterpret_cast(cy); + uint32_t c = reinterpret_cast(cz); + uint32_t d = reinterpret_cast(cw); + + // Pack the numbers. + uint32_t dst; + asm volatile("prmt.b32 %0, %1, %2, 0x0040;\n" : "=r"(dst) : "r"(a), "r"(b)); + asm volatile("prmt.b32 %0, %0, %1, 0x0410;\n" : "+r"(dst) : "r"(c)); + asm volatile("prmt.b32 %0, %0, %1, 0x4210;\n" : "+r"(dst) : "r"(d)); + return dst; +#else + uint32_t a; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(a) : "f"(x)); + uint32_t b; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(b) : "f"(y)); + uint32_t c; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(c) : "f"(z)); + uint32_t d; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(d) : "f"(w)); + + uint32_t dst; + asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;\n" : "=r"(dst) : "r"(d), "r"(c)); + asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, %0;\n" : "+r"(dst) : "r"(b), "r"(a)); + return dst; +#endif // defined(USE_F2I_EMULATION_TRICK) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ void swizzle_rows(uint32_t& a, uint32_t& b, uint32_t c, uint32_t d) { + asm volatile("prmt.b32 %0, %1, %2, 0x6420;\n" : "=r"(a) : "r"(c), "r"(d)); + asm volatile("prmt.b32 %0, %1, %2, 0x7531;\n" : "=r"(b) : "r"(c), "r"(d)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsm_with_lds(uint2& data, uint32_t smem) { + int lane = threadIdx.x % 32; + data = {0, 0}; + uint4 v = {0, 0, 0, 0}; + uint32_t* a = reinterpret_cast(&v); + if (lane < 16) { + fmha::lds(v, smem); + } + int src_row = lane / 4; + int src_col = lane % 4; + for (int it = 0; it < 4; it++) { + uint32_t val = a[it]; + uint32_t x = __shfl_sync(uint32_t(-1), val, src_row); + __syncwarp(); + uint32_t y = __shfl_sync(uint32_t(-1), val, src_row + 8); + __syncwarp(); + if (it == src_col) { + data.x = x; + data.y = y; + } + } +} + +inline __device__ void ldsmt_with_lds(uint2& data, uint32_t smem) { + int lane = threadIdx.x % 32; + + uint4 tmp16{0, 0, 0, 0}; // 16B + + if (lane < 16) { + fmha::lds(tmp16, smem); + } + + uint16_t* tmp16c = reinterpret_cast(&tmp16); // 8x2B: we move pairs + + uint16_t* t = reinterpret_cast(&data); // 4x2B + + int const src_col = lane / 4; // 0 - 7 + int const src_row = (lane % 4) * 2; + +// we have to shuffle the values to distribute them in the warp +#pragma unroll + for (int it = 0; it < 8; it++) { + uint16_t val, x, y; + val = tmp16c[it]; + x = __shfl_sync(uint32_t(-1), val, src_row + 0); + __syncwarp(); + y = __shfl_sync(uint32_t(-1), val, src_row + 1); + __syncwarp(); + + if (src_col == it) { + t[0] = x; + t[1] = y; + } + val = tmp16c[it]; + x = __shfl_sync(uint32_t(-1), val, src_row + 8); + __syncwarp(); + y = __shfl_sync(uint32_t(-1), val, src_row + 9); + __syncwarp(); + + if (src_col == it) { + t[2] = x; + t[3] = y; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { + __device__ inline T operator()(T const& x, T const& y) { return x > y ? 
x : y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { + __device__ inline T operator()(T const& x, T const& y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + + template + static __device__ inline T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Allreduce<2> { + template + static __device__ inline T run(T x, Operator& op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator& op) { +#pragma unroll + for (int mi = 0; mi < M; mi++) { + dst[mi] = src[mi]; + dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2)); + dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator& op) { + float tmp[M]; +#pragma unroll + for (int mi = 0; mi < M; mi++) { + tmp[mi] = op(src[mi].x, src[mi].y); + } + quad_reduce(dst, tmp, op); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator& op) { +#pragma unroll + for (int mi = 0; mi < M; mi++) { + dst[mi] = src[mi]; + dst[mi] = Allreduce<4>::run(dst[mi], op); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator& op) { + float tmp[M]; +#pragma unroll + for (int mi = 0; mi < M; mi++) { + tmp[mi] = op(src[mi].x, src[mi].y); + } + quad_allreduce(dst, tmp, op); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t elect_one_sync() { + uint32_t pred = 0; +#if __CUDA_ARCH__ >= 900 +#if !defined(__CUDACC_RTC__) + uint32_t laneid = 0; + asm volatile( + "\n\ + {\n\ + .reg .b32 %rx;\n\ + .reg .pred %px;\n\ + elect.one.sync %rx|%px, %2;\n\ + @%px mov.s32 %1, 1;\n\ + mov.s32 %0, %rx;\n\ + }\n" + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); +#else + pred = threadIdx.x == 0; +#endif +#endif + return pred; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint16_t float2_to_e4m3x2(float x, float y) { +#if defined(__CUDA_ARCH__) && \ + ((__CUDA_ARCH__ == 890 && defined(FMHA_ENABLE_SM89_QMMA)) || (__CUDA_ARCH__ >= 900)) + uint16_t res; + asm volatile("cvt.rn.e4m3x2.f32.satfinite %0, %2, %1;" : "=h"(res) : "f"(x), "f"(y)); + return res; +#else + assert(false); + return 0; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t float4_to_e4m3x4(float x, float y, float z, float w) { +#if 
defined(__CUDA_ARCH__) && \ + ((__CUDA_ARCH__ == 890 && defined(FMHA_ENABLE_SM89_QMMA)) || (__CUDA_ARCH__ >= 900)) + uint32_t res; + asm volatile( + "{\n" + ".reg .b16 lo;\n" + ".reg .b16 hi;\n" + "cvt.rn.e4m3x2.f32.satfinite lo, %2, %1;\n" + "cvt.rn.e4m3x2.f32.satfinite hi, %4, %3;\n" + "mov.b32 %0, {lo, hi};\n" + "}" + : "=r"(res) + : "f"(x), "f"(y), "f"(z), "f"(w)); + return res; +#else + assert(false); + return 0; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t float4_to_e5m2x4(float x, float y, float z, float w) { +#if defined(__CUDA_ARCH__) && \ + ((__CUDA_ARCH__ == 890 && defined(FMHA_ENABLE_SM89_QMMA)) || (__CUDA_ARCH__ >= 900)) + uint32_t res; + asm volatile( + "{\n" + ".reg .b16 lo;\n" + ".reg .b16 hi;\n" + "cvt.rn.e5m2x2.f32.satfinite lo, %2, %1;\n" + "cvt.rn.e5m2x2.f32.satfinite hi, %4, %3;\n" + "mov.b32 %0, {lo, hi};\n" + "}" + : "=r"(res) + : "f"(x), "f"(y), "f"(z), "f"(w)); + return res; +#else + assert(false); + return 0; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t half4_to_e4m3x4(uint32_t const h2_0, uint32_t const h2_1) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)) + uint32_t res; + asm volatile( + "{\n" + ".reg .b16 lo, hi;\n" + "cvt.satfinite.rn.e4m3x2.f16x2 lo, %1;\n" + "cvt.satfinite.rn.e4m3x2.f16x2 hi, %2;\n" + "mov.b32 %0, {lo, hi};\n" + "}\n" + : "=r"(res) + : "r"(h2_0), "r"(h2_1)); + return res; +#else + assert(false); + return 0; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t half4_to_e5m2x4(uint32_t const h2_0, uint32_t const h2_1) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)) + uint32_t res; + asm volatile( + "{\n" + ".reg .b16 lo, hi;\n" + "cvt.satfinite.rn.e5m2x2.f16x2 lo, %1;\n" + "cvt.satfinite.rn.e5m2x2.f16x2 hi, %2;\n" + "mov.b32 %0, {lo, hi};\n" + "}\n" + : "=r"(res) + : "r"(h2_0), "r"(h2_1)); + return res; +#else + assert(false); + return 0; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Helpers to pack float4 into a destination register with 4 8bit values +template +inline __device__ uint32_t float4_to_8bitx4(float const x, float const y, float const z, + float const w) { + return float4_to_char4(x, y, z, w); +}; + +template <> +inline __device__ uint32_t float4_to_8bitx4(float const x, float const y, float const z, + float const w) { + return float4_to_e4m3x4(x, y, z, w); +}; + +template <> +inline __device__ uint32_t float4_to_8bitx4(float const x, float const y, float const z, + float const w) { + return float4_to_e5m2x4(x, y, z, w); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ uint32_t half4_to_fp8x4(uint32_t const h2_0, uint32_t const h2_1); + +template <> +inline __device__ uint32_t half4_to_fp8x4(uint32_t const h2_0, uint32_t const h2_1) { + return half4_to_e4m3x4(h2_0, h2_1); +} + +template <> +inline __device__ uint32_t half4_to_fp8x4(uint32_t const h2_0, uint32_t const h2_1) { + return half4_to_e5m2x4(h2_0, h2_1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ uint32_t float4_to_fp8x4(float const, float const, float const, float const); + +template <> +inline __device__ 
uint32_t float4_to_fp8x4(float const x, float const y, + float const z, float const w) { + return float4_to_e4m3x4(x, y, z, w); +} + +template <> +inline __device__ uint32_t float4_to_fp8x4(float const x, float const y, + float const z, float const w) { + return float4_to_e5m2x4(x, y, z, w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void fence_view_async_shared() { + // Issue a shared memory fence for async operations (FENCE.VIEW.ASYNC.S) + // only compiles on sm90+ + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("fence.proxy.async.shared::cta;\n"); +#else + assert(false); +#endif +} + +inline __device__ void fence_view_async_global() { + // Issue a global memory fence for async operations (FENCE.VIEW.ASYNC.G) + // only compiles on sm90+ + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("fence.proxy.async.global::cta;\n"); +#else + assert(false); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ char* align_1024(char* ptr) { + uint64_t address_bit = reinterpret_cast(ptr); + uint64_t offset = address_bit % 1024; + if (offset == 0) { + return ptr; + } else { + return ptr + (1024 - offset); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float atomicMaxFloat(float* addr, float value) { + float old; + old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); + return old; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float atomicMaxFloatPos_(float* addr, float value) { + // VALUE MUST BE POSITIVE! USED ONLY FOR INTERNAL AMAX REDUCTION. + float old = __int_as_float(atomicMax((int*)addr, __float_as_int(value))); + return old; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float max3Pos_(float const a, float const b, float const c) { + // VALUE MUST BE POSITIVE! USED ONLY FOR INTERNAL AMAX REDUCTION. + float res; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + int32_t a_ = reinterpret_cast(a); + int32_t b_ = reinterpret_cast(b); + int32_t c_ = reinterpret_cast(c); + int32_t tmp; + asm volatile("max.s16x2 %0, %1, %2;\n" : "=r"(tmp) : "r"(a_), "r"(b_)); + asm volatile("max.s16x2 %0, %0, %1;\n" : "+r"(tmp) : "r"(tmp), "r"(c_)); + res = reinterpret_cast(tmp); +#else + res = fmaxf(a, fmaxf(b, c)); +#endif + return res; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Fast approximate tanh. +static inline __device__ float __tanhf(float x) { +#if (__CUDA_ARCH__ >= 750) + float r = x; + asm("tanh.approx.f32 %0, %0;" : "+f"(r)); + return r; +#else + return tanhf(x); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/warpspec/circular_buffer.h b/csrc/fmha_v2/fmha/warpspec/circular_buffer.h new file mode 100644 index 0000000000..903319490a --- /dev/null +++ b/csrc/fmha_v2/fmha/warpspec/circular_buffer.h @@ -0,0 +1,399 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. 
SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include +#include + +#pragma once + +namespace fmha { +namespace ws { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/* Shared storage for barriers needed by both producer and consumer */ +template +struct CircularBufferBarriers { + __align__(8) uint64_t entryProducedBarriers[DEPTH]; + __align__(8) uint64_t entryConsumedBarriers[DEPTH]; + + CircularBufferBarriers() = default; + // CircularBufferBarriers must live in __shared__ -- cannot copy + CircularBufferBarriers(CircularBufferBarriers const& other) = delete; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/* Producer class */ +template +class CircularBufferWriter { + protected: + uint32_t _wptr; + uint32_t _phase; + fmha::Arrive_wait _entryConsumedBarriers; + fmha::Arrive_wait _entryProducedBarriers; + + public: + inline __device__ CircularBufferWriter(CircularBufferBarriers* barriers) + : _entryProducedBarriers(barriers->entryProducedBarriers), + _entryConsumedBarriers(barriers->entryConsumedBarriers), + _wptr(0), + _phase(0xffffffff) {} + + inline __device__ int ptr() { return _wptr; } + + // Return the equivalent read phase. + inline __device__ int phase() { return _phase ^ 0xffffffff; } + + /* Reserve space in the buffer for TMA */ + inline __device__ int tmaReserve(int tid0, int transactioncnt) { + int ptr = threadReserve(); + _entryProducedBarriers.bar_arrive_set_transactioncnt(ptr, transactioncnt, tid0); + return ptr; + } + + /* Reserve space in the buffer for producer threads */ + inline __device__ int threadReserve() { + wait(); + return advance(); + } + + inline __device__ int advance() { + int rval = _wptr; + _phase ^= (1 << _wptr); + _wptr += 1; + if (_wptr >= DEPTH) { + _wptr = 0; + } + return rval; + } + + /* Wait for space to become available in the buffer */ + inline __device__ int wait() { + int ready = _entryConsumedBarriers.bar_peek(_wptr, (_phase >> _wptr) & 1); + if (!ready) _entryConsumedBarriers.bar_wait(_wptr, (_phase >> _wptr) & 1); + return _wptr; + } + + /* Signal that data is ready */ + inline __device__ void threadCommit(int tid0, int id) { + if (tid0) { + _entryProducedBarriers.bar_arrive_normal(id); + } + } + + /* Get the barrier address, needed by TMA */ + inline __device__ uint64_t* barrier_ptr(int id) { + return _entryProducedBarriers.get_bar_addr(id); + } + + inline __device__ void setPtr(int ptr) { _wptr = ptr; } + + inline __device__ void setPhase(int phase) { _phase = phase; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/* Consumer class */ +template +class CircularBufferReader { + private: + uint32_t _rptr; + uint32_t _phase; + + public: + fmha::Arrive_wait _entryProducedBarriers; + fmha::Arrive_wait _entryConsumedBarriers; + + inline __device__ CircularBufferReader(CircularBufferBarriers* barriers) + : _entryProducedBarriers(barriers->entryProducedBarriers), + _entryConsumedBarriers(barriers->entryConsumedBarriers), + 
_rptr(0), + _phase(0) {} + + inline __device__ void setProducerCta(int cta_id) { + _entryConsumedBarriers.set_bar_base_dsmem(cta_id); + } + + /* Peek at the head */ + inline __device__ int peek() { + return _entryProducedBarriers.bar_peek(_rptr, (_phase >> _rptr) & 1); + } + + /* Wait for the head to be ready */ + inline __device__ int wait() { + _entryProducedBarriers.bar_wait(_rptr, (_phase >> _rptr) & 1); + return _rptr; + } + + /* Advance the head pointer */ + inline __device__ void advance() { + _phase ^= (1 << _rptr); + _rptr += 1; + if (_rptr >= DEPTH) { + _rptr = 0; + } + } + + inline __device__ int ptr() { return _rptr; } + + inline __device__ uint32_t phase() { return _phase; } + + /* Indicate consumption of data at specified pointer. + The producer is now free to overwrite it + */ + inline __device__ void complete(int tid0, int ptr) { + if (tid0) { + if (CGA_SIZE > 1) { + _entryConsumedBarriers.bar_arrive_dsmem(ptr); + } else { + _entryConsumedBarriers.bar_arrive_normal(ptr); + } + } + } + + /* Simplification of complete and advance for cases + where they don't need to be reordered/separated for performance + */ + inline __device__ void pop(int tid0) { + complete(tid0, _rptr); + advance(); + } + + /* Overrides for pointer and phase. Used for shared buffers */ + inline __device__ void setPtr(int ptr) { _rptr = ptr; } + + inline __device__ void setPhase(uint32_t phase) { _phase = phase; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class CircularBuffer { + protected: + CircularBufferBarriers _barriers; + + public: + inline __device__ void init(int tid0, int producer_thread_count, int consumer_thread_count) { + if (tid0) { + for (int i = 0; i < DEPTH; i++) { + fmha::bar_create(&_barriers.entryProducedBarriers[i], producer_thread_count); + fmha::bar_create(&_barriers.entryConsumedBarriers[i], consumer_thread_count); + } + } + } + + using Reader = CircularBufferReader; + using Writer = CircularBufferWriter; + + inline __device__ Reader createReader() { return Reader(&_barriers); } + + inline __device__ Writer createWriter() { return Writer(&_barriers); } + + inline __device__ int depth() { return DEPTH; } + + CircularBuffer() = default; + // CircularBuffer must live in __shared__ -- cannot copy + CircularBuffer(CircularBuffer const& other) = delete; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class CircularBufferWithDataReader : public CircularBufferReader { + private: + T* _data; + + public: + inline __device__ CircularBufferWithDataReader(CircularBufferBarriers* barriers, T* data) + : CircularBufferReader(barriers), _data(data) {} + + inline __device__ T read() { return _data[this->ptr()]; } + + inline __device__ T pop(int tid0, bool read_data = true) { + T rval; + int ready = this->peek(); + if (!ready) this->wait(); + if (read_data) { + rval = read(); + fmha::fence_view_async_shared(); + } + this->complete(tid0, this->ptr()); + this->advance(); + return rval; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class CircularBufferWithDataWriter : public CircularBufferWriter { + private: + T* _data; + + public: + inline __device__ CircularBufferWithDataWriter(CircularBufferBarriers* barriers, T* data) + : CircularBufferWriter(barriers), _data(data) {} + + inline __device__ void write(int ptr, T const& wrdat) { _data[ptr] = wrdat; } + + inline __device__ int 
push(int tid0, T const& wrdat, bool writeData = true, + uint32_t transactioncnt = 0) { + int ptr = this->threadReserve(); + if (tid0 && writeData) { + write(ptr, wrdat); + __threadfence_block(); + } + if (transactioncnt == 0) + this->threadCommit(tid0, ptr); + else + this->_entryProducedBarriers.bar_arrive_set_transactioncnt(ptr, transactioncnt, tid0); + return ptr; + } + + template + inline __device__ int push_with_sync(int tid0, T const& wrdat, bool writeData = true, + uint32_t transactioncnt = 0) { + int ptr = this->threadReserve(); + named_barrier_wait(SYNC_BAR, SYNC_THREADS); + if (tid0 && writeData) { + write(ptr, wrdat); + __threadfence_block(); + } + if (transactioncnt == 0) + this->threadCommit(tid0, ptr); + else + this->_entryProducedBarriers.bar_arrive_set_transactioncnt(ptr, transactioncnt, tid0); + return ptr; + } + + inline __device__ void broadcast(T const& wrdat) { + int offset = this->threadReserve(); + for (int i = 0; i < CGA_SIZE; i++) { + push_to_cta(wrdat, i, offset); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class CircularBufferWithData : public CircularBuffer { + private: + T _data[DEPTH]; + + public: + inline __device__ T* data() { return _data; } + + using Reader = CircularBufferWithDataReader; + using Writer = CircularBufferWithDataWriter; + + inline __device__ Reader createReader() { return Reader(&this->_barriers, _data); } + + inline __device__ Writer createWriter() { return Writer(&this->_barriers, _data); } + + CircularBufferWithData() = default; + // Must live in __shared__ -- cannot copy + CircularBufferWithData(CircularBufferWithData const& other) = delete; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct OrderedMutex { + uint64_t barriers[2]; + + inline __device__ void init(int tid0, int threads0, int threads1) { + if (tid0) { + fmha::bar_create(&barriers[0], threads0); + fmha::bar_create(&barriers[1], threads1); + } + } +}; + +class OrderedMutexAccessor { + private: + int _phase; + int _id; + int _barrier_id; + + fmha::Arrive_wait _barriers; + + public: + inline __device__ OrderedMutexAccessor(OrderedMutex& m, int id, int barrier_id) + : _phase(0), _id(id), _barriers(m.barriers), _barrier_id(barrier_id) {} + + inline __device__ void arrive() { _barriers.bar_arrive_normal(_id); } + + inline __device__ void wait() { + int ready = _barriers.bar_peek(_id ^ 1, _phase); + if (!ready) { + _barriers.bar_wait(_id ^ 1, _phase); + } + _phase ^= 1; + } + + inline __device__ void named_bar_arrive() { + // ... + // Softmax ends + // Make sure barrier is not moving around + if (_id == 0) { + named_barrier_wait(_barrier_id, 256); + } + } + + inline __device__ void named_bar_wait() { + // Make sure barrier is not moving around + if (_id == 1) { + named_barrier_wait(_barrier_id, 256); + } + // Softmax starts + // ... 
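// ----------------------------------------------------------------------------------------------
// Illustrative sketch (not part of the header above): a host-side analogue of the OrderedMutex
// handshake, following the pairing used above, where group `id` arrives on slot `id` and waits on
// slot `id ^ 1`, and one group performs an initial arrive so the other takes the first turn. The
// real code uses hardware named barriers and phase bits; std::binary_semaphore (C++20) only
// mimics the ordering, and all names below are hypothetical.
// ----------------------------------------------------------------------------------------------
#include <cstdio>
#include <semaphore>
#include <thread>

static std::binary_semaphore g_slot0{0}, g_slot1{0};
static std::binary_semaphore* g_slots[2] = {&g_slot0, &g_slot1};

static void compute_group(int id, int steps) {
  for (int s = 0; s < steps; ++s) {
    g_slots[id ^ 1]->acquire();  // wait(): block until the other group has arrived
    std::printf("group %d owns the ordered section for step %d\n", id, s);
    g_slots[id]->release();      // arrive(): hand the section to the other group
  }
}

int main() {
  g_slot1.release();  // mirrors warpgroup 1's initial arrive(), giving group 0 the first turn
  std::thread g0(compute_group, 0, 4), g1(compute_group, 1, 4);
  g0.join();
  g1.join();
  return 0;
}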
+ } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ComputeGroupBarrier { + uint64_t barrier; + + inline __device__ void init(int tid0, int threads) { + if (tid0) { + fmha::bar_create(&barrier, threads); + } + } +}; + +class ComputeGroupBarrierAccessor { + private: + int _phase; + fmha::Arrive_wait _barrier; + + public: + inline __device__ ComputeGroupBarrierAccessor(ComputeGroupBarrier& m) + : _phase(0), _barrier(&m.barrier) {} + + inline __device__ void arrive() { _barrier.bar_arrive_normal(0); } + + inline __device__ void wait() { + int ready = _barrier.bar_peek(0, _phase); + if (!ready) { + _barrier.bar_wait(0, _phase); + } + _phase ^= 1; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace ws +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/warpspec/compute.h b/csrc/fmha_v2/fmha/warpspec/compute.h new file mode 100644 index 0000000000..9aae70b2e7 --- /dev/null +++ b/csrc/fmha_v2/fmha/warpspec/compute.h @@ -0,0 +1,606 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "fmha/alibi_params.h" +#include "fmha/hopper/fragment.h" +#include "fmha/hopper/utils_warpgroup.h" +#include "fmha/softmax.h" +#include "fmha/warpspec/circular_buffer.h" +#include "fmha/warpspec/dma.h" +#include "fmha/warpspec/epilogue.h" + +namespace fmha { +namespace ws { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // Template instruction traits to specialize structs + template class Instruction_traits, + // Kernel Traits + typename Kernel_traits> +struct Compute { + // The shared struct. + using Shared = typename Kernel_traits::Shared; + + // The q, or kv tile reader. + using Circular_buffer_q_reader = typename Kernel_traits::Circular_buffer_q_reader; + using Circular_buffer_kv_reader = typename Kernel_traits::Circular_buffer_kv_reader; + + // The instruction traits for BMM1. + using Traits_p = typename Kernel_traits::Traits_p; + // The instruction traits for BMM2. + using Traits_o = typename Kernel_traits::Traits_o; + + // The CTA description for BMM1. + using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + // The CTA description for BMM2. + using Cta_tile_o = typename Kernel_traits::Cta_tile_o; + + // The Q shared memory tile. + using Smem_tile_q = typename Kernel_traits::Smem_tile_q; + // The K shared memory tile. + using Smem_tile_k = typename Kernel_traits::Smem_tile_k; + // The V shared memory tile. + using Smem_tile_v = typename Kernel_traits::Smem_tile_v; + + // The GMMA compute tile for BMM1. + using Compute_tile_p = typename Kernel_traits::Compute_tile_p; + // The GMMA compute tile for BMM2. + using Compute_tile_o = typename Kernel_traits::Compute_tile_o; + + // The MMA tile for the BMM1. + using Mma_tile_p = typename Kernel_traits::Mma_tile_p; + // The MMA tile for the BMM2. 
+ using Mma_tile_o = typename Kernel_traits::Mma_tile_o; + + // The fragment of BMM1 output. + using Fragment_p = typename Compute_tile_o::Fragment; + + // The global memory tile for storing BMM2 output. + using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o; + + // Softmax + using Softmax = Softmax; + + // BMM2 epilogue + using Tile_o_epilogue = Tile_o_epilogue; + + // The step size of Q loop. + enum { STEP_Q = Kernel_traits::STEP_Q }; + + // The step size of KV loop. + enum { STEP_KV = Kernel_traits::STEP_KV }; + + // The number of compute groups (currently fixed at 2). + enum { NUM_COMPUTE_GROUPS = Kernel_traits::NUM_COMPUTE_GROUPS }; + + // Whether we skip those masked tiles when causal mask is enabled ? + enum { SKIP_CAUSAL_MASK_TILES = Kernel_traits::CAUSAL_MASK && !Kernel_traits::USE_CUSTOM_MASK }; + + // Whether we attend to the specific sliding window or chunk ? + enum { SLIDING_OR_CHUNKED_ATTENTION = Kernel_traits::SLIDING_OR_CHUNKED_ATTENTION }; + + // Are we applying alibi bias (drop FMA optimizations for accuracy reasons). + enum { APPLY_ALIBI = Kernel_traits::APPLY_ALIBI }; + + // Do we use custom mask input ? + enum { USE_CUSTOM_MASK = Kernel_traits::USE_CUSTOM_MASK }; + + // Do we always need to apply the mask ? + enum { ALWAYS_APPLY_MASK = APPLY_ALIBI || USE_CUSTOM_MASK }; + + // Enable mutex for overlapping mma and softmax instructions. + enum { ENABLE_MUTEX = Kernel_traits::ENABLE_MUTEX }; + + // The head_dimension groups. + enum { D_GROUPS = Kernel_traits::D_GROUPS }; + + // The MMA_K groups (corresponding to head_dimension groups). + enum { BMM1_MMAS_K_GROUPS = Kernel_traits::D_GROUPS }; + + // The number of MMAS_K for each head_dimension group. + enum { BMM1_MMAS_K_PER_GROUP = Mma_tile_p::MMAS_K / BMM1_MMAS_K_GROUPS }; + + // The MMA_K groups (corresponding to kv_step groups). + enum { BMM2_MMAS_K_GROUPS = Kernel_traits::BMM2_K_GROUPS }; + + // The number of MMAS_K for each head_dimension group. + enum { BMM2_MMAS_K_PER_GROUP = Mma_tile_o::MMAS_K / BMM2_MMAS_K_GROUPS }; + + // The tile size of V after head_dimension split. + enum { TILE_SIZE_V_PER_D_GROUP = STEP_KV * Kernel_traits::D_PER_GROUP }; + + enum { TILE_SIZE_V = STEP_KV * Kernel_traits::DV }; + + enum { TILE_BYTES_V_PER_D_GROUP = STEP_KV * Kernel_traits::D_BYTES_PER_GROUP }; + + enum { TILE_BYTES_V_PER_K_GROUP = BMM2_MMAS_K_PER_GROUP * Kernel_traits::D_BYTES_PER_GROUP }; + + // Named barrier for inter-warpgroup sync + enum { SYNC_BARRIER = Kernel_traits::MMA_SYNC_BARRIER_ID }; + + // Whether Q and KV is in separate buffer, which means we need to consider different Q and KV + // lengths. + enum { SEPARATE_Q_KV_BUFFER = Kernel_traits::SEPARATE_Q_KV_BUFFER }; + + enum { SAGE_BLOCK_SIZE_Q = Kernel_traits::SAGE_BLOCK_SIZE_Q }; + + // sanitize 0 to -1, avoid DIV BY ZERO below + enum { + SAGE_BLOCK_SIZE_K = Kernel_traits::SAGE_BLOCK_SIZE_K > 0 ? Kernel_traits::SAGE_BLOCK_SIZE_K : -1 + }; + + enum { + SAGE_BLOCK_SIZE_V = Kernel_traits::SAGE_BLOCK_SIZE_V > 0 ? 
Kernel_traits::SAGE_BLOCK_SIZE_V : -1 + }; + + // BLOCK_SIZE_Q should be multiply of STEP_Q (usually 64) so that q scale can be fused into + // scale_bmm1 + static_assert(SAGE_BLOCK_SIZE_Q < 0 || SAGE_BLOCK_SIZE_Q % STEP_Q == 0); + static_assert(SAGE_BLOCK_SIZE_K < 0 || SAGE_BLOCK_SIZE_K % 8 == 0); // 8 = columns of a gmma CORE + static_assert(SAGE_BLOCK_SIZE_V < 0 || + SAGE_BLOCK_SIZE_V % 32 == 0); // 32 = K dimension of a qgmma + + // SAGE_BLOCKS_PER_STEP_X is used to declare scale buffer like `float + // scales_k[SAGE_BLOCKS_PER_STEP_K];` if SAGE_BLOCKS_PER_STEP_X == 0, you will get `zero-sized + // variable is not allowed in device code` error from nvcc, so the minimal value have to be 1. But + // don't worry, unused local variables will be optimized out by compiler. + enum { SAGE_BLOCKS_PER_STEP_K = std::max(STEP_KV / SAGE_BLOCK_SIZE_K, 1) }; + + enum { SAGE_BLOCKS_PER_STEP_V = std::max(STEP_KV / SAGE_BLOCK_SIZE_V, 1) }; + +#define K_TILE_WAIT() \ + int ready_k = cbr_k.peek(); \ + if (!ready_k) { \ + cbr_k.wait(); \ + } + +#define KV_TILE_COMPLETE() \ + cbr_k.complete(tidx == 0, cbr_k.ptr()); \ + cbr_v.complete(tidx == 0, cbr_v.ptr()); \ + cbr_k.advance(); \ + cbr_v.advance(); + +#define COMPUTE_SINGLE_TILE(IS_FIRST_COL, APPLY_MASK) \ + compute_single_tile( \ + params, ctile_p, softmax, ctile_o, p_max, p_sum, tidx, actual_kv_seqlen, alibi_head_scale, \ + USE_CUSTOM_MASK ? (head_info.mask_sum_s + q_step_idx * STEP_Q + local_q_tile_offset) \ + : (q_step_idx * STEP_Q + head_info.q_tile_offset), \ + kv_step_idx * STEP_KV, sage_scale_row, cbr, cbr_v, mutex_accessor, \ + kv_step_idx == kv_idx_end - 1); + + //////////////////////////////////////////////////////////////////////////////////////////////// + + inline __device__ int div_up(int a, int b) { return (a + b - 1) / b; } + + //////////////////////////////////////////////////////////////////////////////////////////////// + + // Compute the kv_left_mask_end and kv_right_mask_start, where mask is applied when kv_idx < + // kv_left_mask_end or kv_idx >= kv_right_mask_start. + template + inline __device__ std::pair compute_kv_mask_start_end(Params const& params, + int const tile_offset_start, + int const tile_offset_end, + int const kv_idx_end) { + // The kv_left_mask_end is 0 by default. + int kv_left_mask_end = 0; + // The kv_right_mask_start is kv_idx_end - 1 by default, which means only the last kv tile is + // masked. + int kv_right_mask_start = kv_idx_end - 1; + + // Always apply mask is specified. + if constexpr (ALWAYS_APPLY_MASK) { + return std::make_pair(0, 0); + } + + // Is the chunked_attention used ? + bool is_chunked_attention = params.log2_chunked_attention_size > 0; + + // The left mask is needed when we attend to a specific sliding window or chunk. + if constexpr (SLIDING_OR_CHUNKED_ATTENTION) { + // The kv_left_mask_end is the start of the chunk. + kv_left_mask_end = + div_up(is_chunked_attention ? ((tile_offset_end >> params.log2_chunked_attention_size) + << params.log2_chunked_attention_size) + : (tile_offset_end + 1 - params.sliding_window_size), + STEP_KV); + } + + // The right mask is needed when causal mask (including sliding_window_attention or chunked + // attention) is used. + if constexpr (SKIP_CAUSAL_MASK_TILES) { + kv_right_mask_start = tile_offset_start / STEP_KV; + } + + // Return the kv_left_mask_end and kv_right_mask_start. 
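// ----------------------------------------------------------------------------------------------
// Illustrative sketch (not part of the file above): the same masking-range computation for the
// plain causal plus sliding-window case, with Kernel_traits replaced by ordinary ints. Tiles with
// kv_idx < left_end need the sliding-window (left) mask, tiles with kv_idx >= right_start cross
// the causal diagonal, and everything in between can take the unmasked fast path. Constant and
// function names are hypothetical.
// ----------------------------------------------------------------------------------------------
#include <algorithm>
#include <utility>

constexpr int kStepKV = 256;

inline int div_up_host(int a, int b) { return (a + b - 1) / b; }

std::pair<int, int> mask_range(int tile_offset_start, int tile_offset_end,
                               int sliding_window_size) {
  // Left boundary: derived from the last Q row of the tile pair, like tile_offset_end above.
  int left_end = std::max(0, div_up_host(tile_offset_end + 1 - sliding_window_size, kStepKV));
  // Right boundary: derived from the first Q row, like tile_offset_start above.
  int right_start = tile_offset_start / kStepKV;
  return {left_end, right_start};
}

// Example: Q rows [1024, 1279], window 512 -> {3, 4}. Any visited KV tile with index < 3 takes
// the left-masked path, tiles 4 and beyond take the causal-masked path, tile 3 runs unmasked.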
+ return std::make_pair(kv_left_mask_end, kv_right_mask_start); + } + + //////////////////////////////////////////////////////////////////////////////////////////////// + + template + inline __device__ void run(int warpgroup_id, int tidx, Shared* shared, Params const& params) { + auto head_tracker = shared->head_info_tracker[warpgroup_id].createReader(); + auto cbr = shared->tma_q_tracker[warpgroup_id].createReader(); + + auto cbr_k = shared->tma_k_tracker.createReader(); + auto cbr_v = shared->tma_v_tracker.createReader(); + + // Ctile_p initialize (relies on q_stage, kv_stage). + char* smem_q = reinterpret_cast(&shared->smem_q[warpgroup_id][0]); + char* smem_k = reinterpret_cast(&shared->smem_k[0]); + Compute_tile_p ctile_p(smem_q, smem_k); + + // Softmax + Softmax softmax(params, tidx); + + // Ctile_o initialize (relies on kv_stage). + uint32_t smem_v = __cvta_generic_to_shared(&shared->smem_v[0]); + Compute_tile_o ctile_o(0, smem_v); + + // Mutex between two compute groups. + OrderedMutexAccessor mutex_accessor(shared->compute_mutex, warpgroup_id, SYNC_BARRIER); + // Notify warpgroup 0 to execute HGMMA first (overlap HGMMA and Softmax Math Instructions). + if (ENABLE_MUTEX && warpgroup_id == 1 && Kernel_traits::ELEMENT_BYTES == 2) { + mutex_accessor.arrive(); + } + + // While loop for different heads. + while (true) { + typename Shared::Head_info head_info = head_tracker.pop(true); + + if (head_info.kv_steps == -1) { + break; + } + + int const kv_steps = head_info.kv_steps; + int const q_steps = head_info.q_steps; + int const local_q_tile_offset = head_info.local_q_tile_offset; + // The global q tile offset (based on past kv cache). + // Not used by custom mask input. + int const q_tile_offset = + SEPARATE_Q_KV_BUFFER ? head_info.q_tile_offset : head_info.local_q_tile_offset; + int const actual_q_seqlen = head_info.actual_seqlen; + // Contiguous QKV FMHA assumes q, and kv have the same sequence length. + int const actual_kv_seqlen = + SEPARATE_Q_KV_BUFFER ? head_info.actual_kv_seqlen : actual_q_seqlen; + + // Calculate the alibi head_scaling_factor. + float alibi_head_scale = APPLY_ALIBI ? get_alibi_head_scaling_factor( + head_info.bidh, params.alibi_params) + : 0.f; + // pre-compute the row of the scale for reuse + int sage_scale_row; + if constexpr (Kernel_traits::SAGE_ATTENTION) { + sage_scale_row = head_info.bidb * params.h + head_info.bidh; + } + + // BMM2 epilogue + Tile_o_epilogue tile_o_epilogue(params, head_info); + + int q_step_idx = warpgroup_id; + + // Compute work. + for (; q_step_idx < q_steps; q_step_idx += NUM_COMPUTE_GROUPS) { + // Check whether it is a valid run of q steps. + int const q_offset = q_step_idx * STEP_Q + local_q_tile_offset; + bool const valid_run = q_offset < actual_q_seqlen; + // fuse the scale of q into scale_bmm1 + if constexpr (SAGE_BLOCK_SIZE_Q > 0) { + // I tried another implementation here: store original `scale_bmm1` to a local variable + // to avoid frequent `__ldg`. But experiment shows that the current one is faster. + // A bit counterintuitive. + auto const scale_bmm1 = + params.scale_bmm1_d ? __ldg(params.scale_bmm1_d) : params.scale_bmm1; + int const idx = sage_scale_row * params.sage.q.max_nblock + q_offset / SAGE_BLOCK_SIZE_Q; + *(float*)(&softmax.scale_bmm1_) = + reinterpret_cast(scale_bmm1) * __ldg(¶ms.sage.q.scales[idx]); + } + + // KV tile is shared by two q tiles, + // so we need to consider the last compute group's q tile. 
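// ----------------------------------------------------------------------------------------------
// Illustrative sketch (not part of the file above): how the two math warpgroups split the Q tiles
// of one head in the loop above, where q_step_idx starts at warpgroup_id and advances by
// NUM_COMPUTE_GROUPS, written as a plain host loop with hypothetical constants.
// ----------------------------------------------------------------------------------------------
#include <cstdio>

int main() {
  const int num_compute_groups = 2;  // NUM_COMPUTE_GROUPS
  const int q_steps = 7;             // STEP_Q-sized tiles assigned to this CTA for this head
  for (int warpgroup_id = 0; warpgroup_id < num_compute_groups; ++warpgroup_id) {
    for (int q = warpgroup_id; q < q_steps; q += num_compute_groups) {
      std::printf("warpgroup %d -> q tile %d\n", warpgroup_id, q);
    }
  }
  // warpgroup 0 handles tiles 0, 2, 4, 6; warpgroup 1 handles tiles 1, 3, 5.
  return 0;
}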
+ int const tile_offset_start = q_step_idx * STEP_Q + q_tile_offset; + int const tile_offset_end = tile_offset_start + STEP_Q - 1; + int const warpgroup_tile_offset_start = tile_offset_start - warpgroup_id * STEP_Q; + int const warpgroup_tile_offset_end = + tile_offset_start + (NUM_COMPUTE_GROUPS - warpgroup_id) * STEP_Q - 1; + + // Compute the kv_idx start (inclusive) and end (exclusive). + auto const [kv_idx_start, kv_idx_end] = DMA::Device::compute_kv_tile_idx( + params, warpgroup_tile_offset_start, warpgroup_tile_offset_end, kv_steps); + + // Compute the kv_left_mask_end and kv_right_mask_start, where mask is applied when kv_idx < + // kv_left_mask_end or kv_idx >= kv_right_mask_start. + auto const [kv_left_mask_end, kv_right_mask_start] = + compute_kv_mask_start_end(params, tile_offset_start, tile_offset_end, kv_idx_end); + + // The gmem O tile. + Gmem_tile_o gmem_o(params, head_info, *shared, tidx, + q_step_idx * STEP_Q + local_q_tile_offset); + + // Q ready to use in smem. + int ready = cbr.peek(); + if (!ready) { + cbr.wait(); + } + + static_assert(Mma_tile_p::CORES_M == 2); + float p_max[Mma_tile_p::CORES_M]; + float p_sum[Mma_tile_p::CORES_M]; + + int kv_step_idx = kv_idx_start; + // First K tiles ready to use in smem. + K_TILE_WAIT(); + // Need to apply mask if only kv tile exists. + if (kv_idx_start < kv_left_mask_end || kv_idx_start >= kv_right_mask_start) { + COMPUTE_SINGLE_TILE(true, true); + } else { + COMPUTE_SINGLE_TILE(true, false); + } + KV_TILE_COMPLETE(); + + for (kv_step_idx += 1; kv_step_idx < kv_right_mask_start; ++kv_step_idx) { + // Current step's K tiles ready to use in smem. + K_TILE_WAIT(); + + // Move kv tile to next buffer. + if (D_GROUPS > 1) { + ctile_p.increment_gmma_desc_group(); + } else { + ctile_p.increment_gmma_desc_b_group(); + } + + ctile_o.increment_gmma_desc_group(); + + // Apply the start mask only when sliding window attention is enabled. + if (kv_step_idx < kv_left_mask_end) { + COMPUTE_SINGLE_TILE(false, true); + } else { + COMPUTE_SINGLE_TILE(false, false); + } + + KV_TILE_COMPLETE(); + } + + // Always apply the mask in the end. + for (; kv_step_idx < kv_idx_end; ++kv_step_idx) { + // Current step's K tiles ready to use in smem. + K_TILE_WAIT(); + + // Move kv tile to next buffer. + if (D_GROUPS > 1) { + ctile_p.increment_gmma_desc_group(); + } else { + ctile_p.increment_gmma_desc_b_group(); + } + + ctile_o.increment_gmma_desc_group(); + + COMPUTE_SINGLE_TILE(false, true); + + KV_TILE_COMPLETE(); + } + if (valid_run) { + // Final step's update. + tile_o_epilogue.scale(ctile_o, p_max, p_sum); + // Store o_tile to gmem. + gmem_o.store(ctile_o.acc_); + } + + // Move q, kv to next buffer. 
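// ----------------------------------------------------------------------------------------------
// Illustrative sketch (not part of the file above): the shape of the KV loop above with the tile
// math replaced by a stub. The first visited tile always starts with fresh accumulators, tiles
// below kv_left_mask_end or at/after kv_right_mask_start take the masked COMPUTE_SINGLE_TILE
// variant, and the middle tiles skip masking entirely.
// ----------------------------------------------------------------------------------------------
#include <cstdio>

void kv_loop_shape(int kv_idx_start, int kv_idx_end, int kv_left_mask_end,
                   int kv_right_mask_start) {
  for (int kv = kv_idx_start; kv < kv_idx_end; ++kv) {
    const bool needs_mask = (kv < kv_left_mask_end) || (kv >= kv_right_mask_start);
    std::printf("kv tile %2d: %s%s\n", kv, needs_mask ? "masked" : "unmasked",
                kv == kv_idx_start ? " (first tile)" : "");
  }
}

// kv_loop_shape(2, 6, 3, 5) prints: tile 2 masked (first tile), tile 3 unmasked, tile 4 unmasked,
// tile 5 masked.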
+ ctile_p.increment_gmma_desc_a_group(); + ctile_p.increment_gmma_desc_b_group(); + ctile_o.increment_gmma_desc_group(); + + if constexpr (Kernel_traits::RETURN_SOFTMAX_STATS) { + using Mma_tile = typename Traits_p::template Mma_tile; + fmha::Softmax_saver_tma saver(params, head_info); + saver.store(p_sum, p_max, sqrtf(params.d), q_step_idx * STEP_Q, valid_run); + } + } + } + } + + //////////////////////////////////////////////////////////////////////////////////////////////// + + template + inline __device__ void compute_single_tile( + Params params, Compute_tile_p& ctile_p, Softmax& softmax, Compute_tile_o& ctile_o, + float (&p_max)[Mma_tile_p::CORES_M], float (&p_sum)[Mma_tile_p::CORES_M], int const tidx, + int const actual_kv_seqlen, float const alibi_head_scale, int const row_offset, + int const col_offset, int const sage_scale_row, Circular_buffer_q_reader& cbr, + Circular_buffer_kv_reader& cbr_v, OrderedMutexAccessor& mutex, bool complete = false) { +// load the scales of K/V from global memory +#define LOAD_SCALES_KV(dst, which, blocks_per_step, block_size) \ + if constexpr (block_size > 0) { \ + const int _start = col_offset / block_size; \ + const float* _src = \ + params.sage.which.scales + sage_scale_row * params.sage.which.max_nblock + _start; \ + const int _end = params.sage.which.max_nblock - _start; \ + _Pragma("unroll") for (int _i = 0; _i < blocks_per_step; _i++) { \ + dst[_i] = _i < _end ? _src[_i] : 1.0f; \ + } \ + } + +#define LOAD_SCALES_K(scales) LOAD_SCALES_KV(scales, k, SAGE_BLOCKS_PER_STEP_K, SAGE_BLOCK_SIZE_K) + +#define LOAD_SCALES_V(scales) LOAD_SCALES_KV(scales, v, SAGE_BLOCKS_PER_STEP_V, SAGE_BLOCK_SIZE_V) + + // Load the needed packed masks. + softmax.load_packed_mask(row_offset, col_offset); + + // experiments show that here is the best place to load scales of K + float scales_k[SAGE_BLOCKS_PER_STEP_K]; + LOAD_SCALES_K(scales_k) + + // Wait until another warpgroup has already executed HGMMA. + if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 2) { + mutex.wait(); + } + + // Ctile_p is only used once by each n step. + ctile_p.clear(); + + // BMM1 (Q x K'). + warpgroup_arrive(); + +// Only single K groups when sizeof(D) <= 128B. +#pragma unroll + for (int kbi = 0; kbi < BMM1_MMAS_K_GROUPS - 1; kbi++) { +#pragma unroll + for (int ki = 0; ki < BMM1_MMAS_K_PER_GROUP; ki++) { + ctile_p.compute(ki, false, ki == BMM1_MMAS_K_PER_GROUP - 1); + } + ctile_p.increment_gmma_desc_group(); + } + +#pragma unroll + for (int ki = 0; ki < BMM1_MMAS_K_PER_GROUP - 1; ki++) { + ctile_p.compute(ki); + } + + ctile_p.compute(BMM1_MMAS_K_PER_GROUP - 1, true, true); + + warpgroup_commit(); + warpgroup_wait<0>(); + + // Arrive when the last tile consumes the q tile. + if (complete) { + cbr.complete(tidx == 0, cbr.ptr()); + cbr.advance(); + } + + if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 2) { + // Notify another warpgroup to execute HGMMA. + mutex.arrive(); + } + if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 1) { + // Wait until another warpgroup has already executed QGMMA. + mutex.named_bar_wait(); + } + + // Fragment p for BMM2 input + Fragment_p frag_p[Mma_tile_o::MMAS_K]; + + // Unpack the elements from bmm1 output to floats. 
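// ----------------------------------------------------------------------------------------------
// Illustrative sketch (not part of the file above): the flash-attention style running-max and
// running-sum update that the p_max / p_sum values and the final Tile_o_epilogue rescale
// correspond to, shown for a single row on the host. This is a generic restatement of the
// algorithm, not the kernel's actual register layout, and it assumes every tile contributes at
// least one score.
// ----------------------------------------------------------------------------------------------
#include <cmath>
#include <vector>

// Returns the softmax denominator of one attention row processed tile by tile.
float online_softmax_denominator(const std::vector<std::vector<float>>& score_tiles) {
  float running_max = -INFINITY;
  float running_sum = 0.f;
  for (const std::vector<float>& tile : score_tiles) {
    float tile_max = running_max;
    for (float s : tile) tile_max = std::fmax(tile_max, s);
    // Rescale everything accumulated so far; the kernel applies the same factor to the
    // O accumulators before adding the new P x V contribution.
    running_sum *= std::exp(running_max - tile_max);
    for (float s : tile) running_sum += std::exp(s - tile_max);
    running_max = tile_max;
  }
  return running_sum;  // the epilogue divides the O accumulators by this value
}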
+ softmax.unpack(ctile_p); + // apply the scales of K before softmax + if constexpr (SAGE_BLOCK_SIZE_K > 0) { +#pragma unroll + for (int ni = 0; ni < Mma_tile_p::CORES_N; ni++) { + float const scale_k = scales_k[SAGE_BLOCKS_PER_STEP_K * ni / Mma_tile_p::CORES_N]; +#pragma unroll + for (int mi = 0; mi < Mma_tile_p::CORES_M; mi++) { + softmax.elt_[mi][2 * ni] *= scale_k; + softmax.elt_[mi][2 * ni + 1] *= scale_k; + } + } + } + + // Apply the alibi and mask. + softmax.apply_alibi_and_mask(ctile_p, params.alibi_params, alibi_head_scale, + actual_kv_seqlen, row_offset, col_offset); + + // Softmax Exp, max/sum, and update scales. + softmax.compute_and_update_scale(p_max, p_sum); + + // experiments show that here is the best place to load scales of V + float scales_v[SAGE_BLOCKS_PER_STEP_V]; + LOAD_SCALES_V(scales_v) + + // Update flash attention scales and pack it for BMM2 + softmax.pack(ctile_o, frag_p); + + if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 1) { + // Notify another warpgroup to execute QGMMA. + mutex.named_bar_arrive(); + } + + // Wait until v buffer is ready. + int ready = cbr_v.peek(); + if (!ready) { + cbr_v.wait(); + } + + warpgroup_arrive(); + + float last_scale_v; + +// Apply the scale of V to partial result. +// Note 2 points: +// 1. Because the matrix V is quantized along the inner dimension, it is necessary to interrupt +// the MMA workflow after processing each BLOCKS_SIZE_V rows of V and scale the intermediate +// results once. For example, STEP_KV=256, qgmma.K=32, then 256/32=8 MMAs are needs, +// so mma_ki = [0,1,2, ..., 7]. If the BLOCK_SIZE_V=64, then after each 2 qgmmas we should scale +// ctile_o. +// 2. The ctile_o is all zero at the beginning. if we directly apply the scale of V after each 2 +// qgmmas, let's see what happens: +// ctile_o = [0] +// ctile_o = (ctile_o + P0 x V0) * s0 = P0 x V0 * s0 +// ctile_o = (ctile_o + P1 x V1) * s1 = P0 x V0 * s0 * s1 + P1 x V1 * s1 +// ctile_o = (ctile_o + P2 x V2) * s2 = P0 x V0 * s0 * s1 * s2 + P1 x V1 * s1 * s2 + P2 x V2 * +// s2 +// ... +// As you see, the actual scale of a V block is the cumulative product of the scales of all +// later blocks. To solve this, we have to preprocess the scale s[i] of block[i] to s[i]/s[i+1], +// and the final block uses the actual scale. +// But to fetch the next scale in next STEP leads to bad performance. So we apply s[i-1]/s[i] to +// current partial result BEFORE each V block. +#define APPLY_SCALE_V(mma_ki) \ + if constexpr (SAGE_BLOCK_SIZE_V > 0) { \ + if (mma_ki % (Mma_tile_o::MMAS_K / SAGE_BLOCKS_PER_STEP_V) == 0) { \ + float _scale_v = scales_v[SAGE_BLOCKS_PER_STEP_V * mma_ki / Mma_tile_o::MMAS_K]; \ + if (mma_ki != 0) { \ + warpgroup_commit(); \ + warpgroup_wait<0>(); \ + } \ + last_scale_v = _scale_v; \ + } \ + } + +// BMM2 (S * V). 
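// ----------------------------------------------------------------------------------------------
// Illustrative sketch (not part of the file above): a numeric check of the scale-ordering
// argument in the comment above. Dequantizing each block up front (reference) gives the same
// result as multiplying the running accumulator by s[i-1]/s[i] before block i and applying the
// last block's scale once at the end. Values are arbitrary stand-ins.
// ----------------------------------------------------------------------------------------------
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  const int n = 4;
  const float pv[n] = {1.0f, 2.0f, 3.0f, 4.0f};  // stand-ins for the per-block P x V partials
  const float s[n] = {0.5f, 2.0f, 1.5f, 0.25f};  // per-block V dequantization scales

  float reference = 0.f;
  for (int i = 0; i < n; ++i) reference += s[i] * pv[i];

  float acc = 0.f;
  for (int i = 0; i < n; ++i) {
    if (i > 0) acc *= s[i - 1] / s[i];  // correction applied BEFORE each new block
    acc += pv[i];
  }
  acc *= s[n - 1];  // the final block's scale is applied once, after the loop

  std::printf("reference=%f reordered=%f\n", reference, acc);
  assert(std::fabs(reference - acc) < 1e-4f);
  return 0;
}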
+#pragma unroll + for (int kbi = 0; kbi < BMM2_MMAS_K_GROUPS - 1; kbi++) { +#pragma unroll + for (int ki = 0; ki < BMM2_MMAS_K_PER_GROUP; ++ki) { + int const mma_ki = kbi * BMM2_MMAS_K_PER_GROUP + ki; + APPLY_SCALE_V(mma_ki) + ctile_o.fill_frag_a(frag_p[mma_ki]); + ctile_o.compute(ki, false, ki == BMM2_MMAS_K_PER_GROUP - 1); + } + ctile_o.increment_gmma_desc_group(); + } + +#pragma unroll + for (int ki = 0; ki < BMM2_MMAS_K_PER_GROUP - 1; ++ki) { + int const mma_ki = (BMM2_MMAS_K_GROUPS - 1) * BMM2_MMAS_K_PER_GROUP + ki; + APPLY_SCALE_V(mma_ki) + ctile_o.fill_frag_a(frag_p[mma_ki]); + ctile_o.compute(ki); + } + + APPLY_SCALE_V((Mma_tile_o::MMAS_K - 1)) + ctile_o.fill_frag_a(frag_p[Mma_tile_o::MMAS_K - 1]); + ctile_o.compute(Mma_tile_o::MMAS_K - 1, true, true); + + warpgroup_commit(); + warpgroup_wait<0>(); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace ws +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/warpspec/dma.h b/csrc/fmha_v2/fmha/warpspec/dma.h new file mode 100644 index 0000000000..a14ccafdf3 --- /dev/null +++ b/csrc/fmha_v2/fmha/warpspec/dma.h @@ -0,0 +1,874 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include +#include +#include +#include + +#include "fmha/hopper/arrive_wait.h" +#include "fmha/hopper/smem_tile.h" +#include "fmha/utils.h" + +namespace fmha { +namespace ws { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct DMA { + // The shared struct. + using Shared = typename Kernel_traits::Shared; + // The kv buffer writer. + using Circular_buffer_kv_writer = typename Kernel_traits::Circular_buffer_kv_writer; + using Circular_buffer_v_scratch_reader = typename Kernel_traits::Circular_buffer_v_scratch_reader; + using Circular_buffer_v_scratch_writer = typename Kernel_traits::Circular_buffer_v_scratch_writer; + + // The step size of Q loop. + enum { STEP_Q = Kernel_traits::STEP_Q }; + + // The step size of KV loop. + enum { STEP_KV = Kernel_traits::STEP_KV }; + + // The tile size of Q. + enum { TILE_SIZE_Q = STEP_Q * Kernel_traits::D }; + + // The tile size of Q after head_dimension split. + enum { TILE_SIZE_Q_PER_D_GROUP = STEP_Q * Kernel_traits::D_PER_GROUP }; + + // The tile size of K. + enum { TILE_SIZE_K = STEP_KV * Kernel_traits::D }; + + // The tile size of K after head_dimension split. + enum { TILE_SIZE_K_PER_D_GROUP = STEP_KV * Kernel_traits::D_PER_GROUP }; + + // The tile size of V. + enum { TILE_SIZE_V = STEP_KV * Kernel_traits::DV }; + + // The tile size of V after head_dimension split. + enum { TILE_SIZE_V_PER_D_GROUP = TILE_SIZE_K_PER_D_GROUP }; + + // Whether apply causal mask or not. + enum { CAUSAL_MASK = Kernel_traits::CAUSAL_MASK }; + + // Whether use custom mask input or not. + enum { USE_CUSTOM_MASK = Kernel_traits::USE_CUSTOM_MASK }; + + // Whether we skip those masked tiles when causal mask is enabled ? 
+ enum { SKIP_CAUSAL_MASK_TILES = CAUSAL_MASK && !USE_CUSTOM_MASK }; + + // Whether we attend to the specific sliding window or chunk ? + enum { SLIDING_OR_CHUNKED_ATTENTION = Kernel_traits::SLIDING_OR_CHUNKED_ATTENTION }; + + // Is heads interleaved ? + enum { HEADS_INTERLEAVED = Kernel_traits::HEADS_INTERLEAVED }; + + // Named barrier for inter-warpgroup sync + enum { SYNC_BARRIER = Kernel_traits::DMA_SYNC_BARRIER_ID }; + + // The number of compute groups (currently fixed at 2). + enum { NUM_COMPUTE_GROUPS = Kernel_traits::NUM_COMPUTE_GROUPS }; + + // The tile scheduling mode: static (0), dynamic (1) + enum { SCHEDULING_MODE = Kernel_traits::SCHEDULING_MODE }; + + // Whether read from paged kv buffers or not. + enum { PAGED_KV_INPUT = Kernel_traits::PAGED_KV_INPUT }; + + // Whether the dma group transposes the v tile explicitly. + enum { DMA_GROUP_TRANSPOSE_V = Kernel_traits::DMA_GROUP_TRANSPOSE_V }; + + // How many threads get involved in the dma group. + enum { NUM_THREADS_IN_DMA_GROUP = Kernel_traits::NUM_THREADS_IN_DMA_GROUP }; + + // Transpose V + // K is the sequence length dimension (128 for GMMA). The unroll factor is decided according to + // empirical evidence so as to avoid register spill. + enum { K_ = STEP_KV % 128 == 0 ? 128 : 64 }; + + static_assert(STEP_KV % K_ == 0); + using Transposer = + Transposer 128 || SLIDING_OR_CHUNKED_ATTENTION) ? 1 : 2 /* UNROLL */>; + + struct Device { + // Only the warpgroup leader initiates mbarriers & TMA operations. + uint32_t elect_one_; + // The sum_s for q. + int sum_s_q_; + // The sum_s for kv. + int sum_s_kv_; + // Tile id for q tile scheduling + uint32_t tile_id_; + + inline __device__ Device(uint32_t elect_one) : elect_one_(elect_one) {} + + //////////////////////////////////////////////////////////////////////////////////////////// + + // Compute the kv tile idx start (inclusive) and end (exclusive). + static inline __device__ std::pair compute_kv_tile_idx( + bert::Fused_multihead_attention_params_v2 const& params, int q_step_offset, int q_step_end, + int kv_steps) { + // The default kv_idx_start and kv_idx_end (exclusive). + int kv_idx_start = 0; + int kv_idx_end = kv_steps; + + // Is the chunked_attention used ? + bool is_chunked_attention = params.log2_chunked_attention_size > 0; + + // Skip initial kv tiles due to sliding_window_size + if (SLIDING_OR_CHUNKED_ATTENTION) { + // The kv_offset_start. + int kv_offset_start = is_chunked_attention + ? ((q_step_offset >> params.log2_chunked_attention_size) + << params.log2_chunked_attention_size) + : max(0, q_step_offset + 1 - params.sliding_window_size); + kv_idx_start = kv_offset_start / STEP_KV; + } + + // Early stop when causal mask is enabled. + if (SKIP_CAUSAL_MASK_TILES) { + kv_idx_end = (q_step_end + STEP_KV - 1) / STEP_KV; + } + + return std::make_pair(kv_idx_start, kv_idx_end); + } + + //////////////////////////////////////////////////////////////////////////////////////////// + + // Packed contiguous QKV input. + inline __device__ void run_packed_qkv(bert::Fused_multihead_attention_params_v2 const& params, + Shared* shared) { + // DMA. 
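// ----------------------------------------------------------------------------------------------
// Illustrative sketch (not part of the file above): compute_kv_tile_idx for the causal plus
// sliding-window case, with the kernel traits replaced by plain arguments. The start skips KV
// tiles that every Q row in the step ignores; the end stops at the causal diagonal of the last
// Q row handled by this DMA step. The chunked-attention branch is omitted and names are
// hypothetical.
// ----------------------------------------------------------------------------------------------
#include <algorithm>
#include <utility>

std::pair<int, int> kv_tile_range(int q_step_offset, int q_step_end, int kv_steps,
                                  int sliding_window_size, int step_kv, bool causal,
                                  bool sliding_window) {
  int kv_idx_start = 0;
  int kv_idx_end = kv_steps;
  if (sliding_window) {
    kv_idx_start = std::max(0, q_step_offset + 1 - sliding_window_size) / step_kv;
  }
  if (causal) {
    kv_idx_end = (q_step_end + step_kv - 1) / step_kv;
  }
  return {kv_idx_start, kv_idx_end};
}

// Example: Q rows [1024, 1535], window 512, 8 KV tiles of 256 tokens ->
// kv_tile_range(1024, 1535, 8, 512, 256, true, true) == {2, 6}: only tiles 2..5 are loaded.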
+ int local_wid = (threadIdx.x / 32) % 4; + int tiw = threadIdx.x % 32; + uint32_t smem_tile_id = __cvta_generic_to_shared(&shared->tile_id); + + if (SCHEDULING_MODE == 0) { + tile_id_ = blockIdx.y; + } else { + get_next_tile_id(local_wid, tiw, smem_tile_id, params.tile_id_counter_ptr); + } + + auto cbw0 = shared->tma_q_tracker[0].createWriter(); + auto cbw1 = shared->tma_q_tracker[1].createWriter(); + Circular_buffer_kv_writer cbw_k = shared->tma_k_tracker.createWriter(); + Circular_buffer_kv_writer cbw_v = shared->tma_v_tracker.createWriter(); + Circular_buffer_v_scratch_reader cbr_v_scratch = shared->tma_v_scratch_tracker.createReader(); + Circular_buffer_v_scratch_writer cbw_v_scratch = shared->tma_v_scratch_tracker.createWriter(); + auto headinfo_tracker0 = shared->head_info_tracker[0].createWriter(); + auto headinfo_tracker1 = shared->head_info_tracker[1].createWriter(); + + while (tile_id_ < params.num_tiles) { + // If we do bidh = next_head % h, we'd guarantee b to be spread across CTAs. + + int bidb, tmp, bidh, q_step_offset, q_steps; + + if (SCHEDULING_MODE == 0) { + bidh = tile_id_ % params.h; + bidb = tile_id_ / params.h; + } else { + // Balanced dynamic scheduling + if (CAUSAL_MASK && !SLIDING_OR_CHUNKED_ATTENTION && params.use_balanced_scheduling) { + q_step_offset = (params.num_tiles_per_head - 1 - tile_id_ / (params.b * params.h)) * + NUM_COMPUTE_GROUPS; + tmp = tile_id_ % (params.b * params.h); + bidh = tmp / params.b; + bidb = tmp % params.b; + q_steps = NUM_COMPUTE_GROUPS; + } else { // Unbalanced dynamic scheduling + bidb = tile_id_ / (params.h * params.num_tiles_per_head); + tmp = tile_id_ % (params.h * params.num_tiles_per_head); + bidh = tmp / params.num_tiles_per_head; + q_step_offset = tmp % params.num_tiles_per_head * NUM_COMPUTE_GROUPS; + q_steps = NUM_COMPUTE_GROUPS; + } + } + + cudaTmaDesc const* desc_q = ¶ms.tma_desc_q; + cudaTmaDesc const* desc_k = ¶ms.tma_desc_k; + cudaTmaDesc const* desc_v = ¶ms.tma_desc_v; + int actual_seqlen; + if (params.is_s_padded) { + sum_s_q_ = bidb * params.s; + actual_seqlen = params.cu_q_seqlens[bidb + 1] - params.cu_q_seqlens[bidb]; + } else { + sum_s_q_ = params.cu_q_seqlens[bidb]; + actual_seqlen = params.cu_q_seqlens[bidb + 1] - sum_s_q_; + } + sum_s_kv_ = sum_s_q_; + + // The cumulative packed_mask seqlens. + // Each sequence length in the batch has to be padded to multiple of 128. + int sum_mask_s = params.cu_mask_rows[bidb]; + + if (SCHEDULING_MODE == 0) { + // split work across M + q_steps = (actual_seqlen + STEP_Q - 1) / STEP_Q; + + // Q_steps may be distributed to multiple blocks to increase the occupacy + // when b*h is small. + // The number of q_steps needs to be multiple of 2. + q_steps = (q_steps + gridDim.x - 1) / gridDim.x; + q_steps += (q_steps & 1); + // The last block may process fewer q_steps. + q_step_offset = q_steps * blockIdx.x; + } + + int q_tile_offset = q_step_offset * STEP_Q; + if (q_tile_offset >= actual_seqlen) { + if (SCHEDULING_MODE == 0) { + tile_id_ += gridDim.y; + } else { + get_next_tile_id(local_wid, tiw, smem_tile_id, params.tile_id_counter_ptr); + } + continue; + } + + // Split work across N. 
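// ----------------------------------------------------------------------------------------------
// Illustrative sketch (not part of the file above): decoding a flat tile_id under the balanced
// dynamic-scheduling branch above, with b, h and num_tiles_per_head as plain ints. Because
// q_step_offset shrinks as tile_id grows, tiles fetched later are the cheaper ones near the top
// of the causal triangle, which evens out work across CTAs. Names are hypothetical.
// ----------------------------------------------------------------------------------------------
#include <cstdio>

struct TileCoords {
  int bidb, bidh, q_step_offset;
};

TileCoords decode_balanced(int tile_id, int b, int h, int num_tiles_per_head,
                           int num_compute_groups) {
  TileCoords c;
  c.q_step_offset = (num_tiles_per_head - 1 - tile_id / (b * h)) * num_compute_groups;
  const int rem = tile_id % (b * h);
  c.bidh = rem / b;
  c.bidb = rem % b;
  return c;
}

int main() {
  for (int t = 0; t < 8; ++t) {
    const TileCoords c = decode_balanced(t, /*b=*/2, /*h=*/2, /*num_tiles_per_head=*/2,
                                         /*num_compute_groups=*/2);
    std::printf("tile_id %d -> bidb=%d bidh=%d q_step_offset=%d\n", t, c.bidb, c.bidh,
                c.q_step_offset);
  }
  return 0;
}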
+ int const kv_steps = (actual_seqlen + STEP_KV - 1) / STEP_KV; + for (int q_step_idx = 0; q_step_idx < q_steps; q_step_idx += 2) { + load_q(bidh, (q_step_idx + 0 + q_step_offset) * STEP_Q, desc_q, shared->smem_q[0], cbw0); + load_q(bidh, (q_step_idx + 1 + q_step_offset) * STEP_Q, desc_q, shared->smem_q[1], cbw1); + + // Q step bound is 2 tiles away at this moment because of 2x1 math warpgroup + int const q_step_end = (q_step_idx + q_step_offset + 2) * STEP_Q - 1; + + // The kv tile idx range for this q step. + auto const [kv_idx_start, kv_idx_end] = compute_kv_tile_idx( + params, (q_step_idx + q_step_offset) * STEP_Q, q_step_end, kv_steps); + + // Iterate over the kv tiles for this q step. + for (int kv_step_idx = kv_idx_start; kv_step_idx < kv_idx_end; kv_step_idx++) { + int bar_id = load_kv(bidh / params.h_q_per_kv, kv_step_idx * STEP_KV, desc_k, desc_v, + shared, cbw_k, cbw_v, cbw_v_scratch); + + // Opportunistically hide headinfo in the shadow of UTMALDGs of the QKV tensor + if (q_step_idx == 0 && kv_step_idx == kv_idx_start) { + // Send head info. + typename Shared::Head_info info{ + q_steps, + // q, and kv have the same length. + q_tile_offset, USE_CUSTOM_MASK ? sum_mask_s : q_tile_offset, kv_steps, + // q, and kv have the same length. + actual_seqlen, actual_seqlen, sum_s_q_ * params.h + bidh, bidh, bidb}; + // NOTE(tizheng): The need for the sync after consumer bar wait is to avoid a deadlock + // hazard when DMA thread 0 is ahead of other DMA threads. For example: DMA thread 0 + // have finished consumer bar wait phase 0 and producer bar arrive phase 0, and then + // MMA warps have finished producer bar wait phase 0 and consumer bar arrive phase 1. + // At this time other DMA threads start consumer bar wait phase 0. It will never + // become ready. DMA warps then fail to continue to the next loop. + // + // It is the same consideration for the sync after tmaReserve in load_q and load_kv + // implementation below. + headinfo_tracker0.template push_with_sync( + elect_one_, info); + headinfo_tracker1.template push_with_sync( + elect_one_, info); + } + + if constexpr (DMA_GROUP_TRANSPOSE_V) { + transpose_v_tile(bar_id, shared, cbw_v, cbr_v_scratch); + } + } // kv + } // q + + if (SCHEDULING_MODE == 0) { + tile_id_ += gridDim.y; + } else { + get_next_tile_id(local_wid, tiw, smem_tile_id, params.tile_id_counter_ptr); + } + } // gridDim.y + // Signal compute groups to break. + headinfo_tracker0.template push_with_sync( + elect_one_, {-1, -1, -1, -1, -1, -1, -1, -1}); + headinfo_tracker1.template push_with_sync( + elect_one_, {-1, -1, -1, -1, -1, -1, -1, -1}); + } + + // Support contiguous Q + contiguous/paged KV separate cache. + inline __device__ void run_separate_q_and_kv( + bert::Fused_multihead_attention_params_v2 const& params, Shared* shared) { + // DMA. 
+ int local_wid = (threadIdx.x / 32) % 4; + int tiw = threadIdx.x % 32; + uint32_t smem_tile_id = __cvta_generic_to_shared(&shared->tile_id); + + if (SCHEDULING_MODE == 0) { + tile_id_ = blockIdx.y; + } else { + get_next_tile_id(local_wid, tiw, smem_tile_id, params.tile_id_counter_ptr); + } + + auto cbw0 = shared->tma_q_tracker[0].createWriter(); + auto cbw1 = shared->tma_q_tracker[1].createWriter(); + Circular_buffer_kv_writer cbw_k = shared->tma_k_tracker.createWriter(); + Circular_buffer_kv_writer cbw_v = shared->tma_v_tracker.createWriter(); + Circular_buffer_v_scratch_reader cbr_v_scratch = shared->tma_v_scratch_tracker.createReader(); + Circular_buffer_v_scratch_writer cbw_v_scratch = shared->tma_v_scratch_tracker.createWriter(); + auto headinfo_tracker0 = shared->head_info_tracker[0].createWriter(); + auto headinfo_tracker1 = shared->head_info_tracker[1].createWriter(); + + while (tile_id_ < params.num_tiles) { + // If we do bidh = next_head % h, we'd guarantee b to be spread across CTAs. + + int bidb, tmp, bidh, local_q_tile_offset, q_steps; + + if (SCHEDULING_MODE == 0) { + bidh = tile_id_ % params.h; + bidb = tile_id_ / params.h; + } else if (SCHEDULING_MODE == 1) { + bidb = tile_id_ / (params.h * params.num_tiles_per_head); + tmp = tile_id_ % (params.h * params.num_tiles_per_head); + bidh = tmp / params.num_tiles_per_head; + local_q_tile_offset = (tmp % params.num_tiles_per_head) * NUM_COMPUTE_GROUPS * STEP_Q; + q_steps = NUM_COMPUTE_GROUPS; + } else { // SCHEDULING_MODE == 2 + local_q_tile_offset = (params.num_tiles_per_head - 1 - tile_id_ / (params.b * params.h)) * + NUM_COMPUTE_GROUPS * STEP_Q; + tmp = tile_id_ % (params.b * params.h); + bidh = tmp / params.b; + bidb = tmp % params.b; + q_steps = NUM_COMPUTE_GROUPS; + } + int bidh_kv = bidh / params.h_q_per_kv; + + // Sequence length parameters. + // Take chunked attention (q, and kv may have difference sequence length) into + // consideration. + sum_s_q_ = params.is_s_padded ? bidb * params.s : params.cu_q_seqlens[bidb]; + sum_s_kv_ = params.is_s_padded ? bidb * params.s : params.cu_kv_seqlens[bidb]; + int actual_q_seqlen = params.cu_q_seqlens[bidb + 1] - params.cu_q_seqlens[bidb]; + int actual_kv_seqlen = params.cu_kv_seqlens[bidb + 1] - params.cu_kv_seqlens[bidb]; + int past_kv_length = actual_kv_seqlen - actual_q_seqlen; + + // The cumulative packed_mask seqlens. + // Each sequence length in the batch has to be padded to multiple of 128. + int sum_mask_s = params.cu_mask_rows[bidb]; + + // Prepare the tma descriptors. + cudaTmaDesc const* desc_q = ¶ms.tma_desc_q; + cudaTmaDesc const* desc_k = ¶ms.tma_desc_k; + cudaTmaDesc const* desc_v = ¶ms.tma_desc_v; + + int32_t const* paged_block_offsets = + params.paged_kv_cache.mBlockOffsets + bidb * 2 * params.paged_kv_cache.mMaxBlocksPerSeq; + + if (SCHEDULING_MODE == 0) { + // split work across M + q_steps = (actual_q_seqlen + STEP_Q - 1) / STEP_Q; + + // Q_steps may be distributed to multiple blocks to increase the occupacy + // when b*h is small. + // The number of q_steps needs to be multiple of 2. + q_steps = (q_steps + gridDim.x - 1) / gridDim.x; + q_steps += (q_steps & 1); + local_q_tile_offset = q_steps * blockIdx.x * STEP_Q; + } + + // The last block may process fewer q_steps. + if (local_q_tile_offset >= actual_q_seqlen) { + if (SCHEDULING_MODE == 0) { + tile_id_ += gridDim.y; + } else { + get_next_tile_id(local_wid, tiw, smem_tile_id, params.tile_id_counter_ptr); + } + continue; + } + + // The global q tile offset which includes the past kv cache. 
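// ----------------------------------------------------------------------------------------------
// Illustrative sketch (not part of the file above): the variable-sequence-length bookkeeping for
// the non-padded path above, computed from the cumulative length arrays. past_kv is the number
// of tokens already present in the KV cache, and it is exactly the amount by which the local Q
// tile offset has to be shifted so the causal diagonal lines up against the full KV sequence.
// Names are hypothetical; the is_s_padded path uses bidb * s instead.
// ----------------------------------------------------------------------------------------------
#include <vector>

struct SeqInfo {
  int sum_s_q, sum_s_kv, actual_q, actual_kv, past_kv;
};

SeqInfo seq_info(const std::vector<int>& cu_q_seqlens, const std::vector<int>& cu_kv_seqlens,
                 int bidb) {
  SeqInfo s;
  s.sum_s_q = cu_q_seqlens[bidb];
  s.sum_s_kv = cu_kv_seqlens[bidb];
  s.actual_q = cu_q_seqlens[bidb + 1] - cu_q_seqlens[bidb];
  s.actual_kv = cu_kv_seqlens[bidb + 1] - cu_kv_seqlens[bidb];
  s.past_kv = s.actual_kv - s.actual_q;
  return s;
}

// Example: cu_q_seqlens = {0, 128, 256}, cu_kv_seqlens = {0, 1152, 2304} -> sequence 0 has 128
// new Q rows against 1152 KV tokens, so past_kv = 1024 and its global q_tile_offset starts there.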
+ int q_tile_offset = local_q_tile_offset + past_kv_length; + // Split work across N. + int const kv_steps = (actual_kv_seqlen + STEP_KV - 1) / STEP_KV; + // Page KV: number of valid kv blocks (others might be nullptr). + int const num_valid_kv_blocks = + (actual_kv_seqlen + params.paged_kv_cache.mTokensPerBlock - 1) >> + params.paged_kv_cache.mTokensPerBlockLog2; + + for (int q_step_idx = 0; q_step_idx < q_steps && actual_kv_seqlen > 0; q_step_idx += 2) { + load_q(bidh, q_step_idx * STEP_Q + local_q_tile_offset, desc_q, shared->smem_q[0], cbw0); + load_q(bidh, (q_step_idx + 1) * STEP_Q + local_q_tile_offset, desc_q, shared->smem_q[1], + cbw1); + + // Q step end is 2 tiles away at this moment because of 2x1 math warpgroup + int const q_step_end = (q_step_idx + 2) * STEP_Q - 1 + q_tile_offset; + + // The kv tile idx range for this q step. + auto const [kv_idx_start, kv_idx_end] = compute_kv_tile_idx( + params, q_step_idx * STEP_Q + q_tile_offset, q_step_end, kv_steps); + + // Iterate over the kv tiles for this q step. + for (int kv_step_idx = kv_idx_start; kv_step_idx < kv_idx_end; kv_step_idx++) { + // The barrier id. + int bar_id; + // Load paged kv input. + if constexpr (PAGED_KV_INPUT) { + bar_id = load_paged_kv(bidh_kv, kv_step_idx * STEP_KV, num_valid_kv_blocks, + params.paged_kv_cache.mTokensPerBlockLog2, + params.blocks_per_tma_load, params.blocks_per_tma_load_log2, + params.paged_kv_cache.mMaxBlocksPerSeq, paged_block_offsets, + desc_k, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch); + } else { + bar_id = load_kv(bidh_kv, kv_step_idx * STEP_KV, desc_k, desc_v, shared, cbw_k, cbw_v, + cbw_v_scratch); + } + + // Opportunistically hide headinfo in the shadow of UTMALDGs of the QKV tensor + if (q_step_idx == 0 && kv_step_idx == kv_idx_start) { + // Send head info. + typename Shared::Head_info info{q_steps, + local_q_tile_offset, + USE_CUSTOM_MASK ? sum_mask_s : q_tile_offset, + kv_steps, + actual_q_seqlen, + actual_kv_seqlen, + sum_s_q_ * params.h + bidh, + bidh, + bidb}; + headinfo_tracker0.template push_with_sync( + elect_one_, info); + headinfo_tracker1.template push_with_sync( + elect_one_, info); + } + if constexpr (DMA_GROUP_TRANSPOSE_V) { + transpose_v_tile(bar_id, shared, cbw_v, cbr_v_scratch); + } + } // kv + } // q + + if (SCHEDULING_MODE == 0) { + tile_id_ += gridDim.y; + } else { + get_next_tile_id(local_wid, tiw, smem_tile_id, params.tile_id_counter_ptr); + } + } // gridDim.y + + // Signal compute groups to break. + headinfo_tracker0.template push_with_sync( + elect_one_, {-1, -1, -1, -1, -1, -1, -1, -1}); + headinfo_tracker1.template push_with_sync( + elect_one_, {-1, -1, -1, -1, -1, -1, -1, -1}); + } + + // Load q tiles from gmem to smem by TMA. 
+ template + inline __device__ void load_q(int bidh, int q_tile_start_offset, cudaTmaDesc const* desc_q, + Smem_q& smem_q, BufferWriter& cbw) { + int barrier_id = cbw.tmaReserve(elect_one_, TILE_SIZE_Q * Kernel_traits::ELEMENT_BYTES); + + named_barrier_wait(SYNC_BARRIER, NUM_THREADS_IN_DMA_GROUP); + + // split D into multiple groups in order to satisfy the TMA 128B sizzle mode +#pragma unroll + for (int di = 0; di < Kernel_traits::D_GROUPS; ++di) { + const int32_t coords[3] = {di * Kernel_traits::D_PER_GROUP, bidh, + sum_s_q_ + q_tile_start_offset}; + fmha::utmaldg<3, fmha::cudaTmaDescType::TILED, false>( + desc_q, + __cvta_generic_to_shared( + &smem_q[barrier_id * TILE_SIZE_Q + di * TILE_SIZE_Q_PER_D_GROUP]), + __cvta_generic_to_shared(cbw.barrier_ptr(barrier_id)), coords, elect_one_); + } + } + +#define PREPARE_KV_BUFFER() \ + int k_barrier_id = cbw_k.tmaReserve(elect_one_, (TILE_SIZE_K) * Kernel_traits::ELEMENT_BYTES); \ + \ + int v_barrier_id; \ + void* v_barrier_ptr; \ + typename Kernel_traits::Element_data_type* v_smem; \ + \ + if constexpr (DMA_GROUP_TRANSPOSE_V) { \ + v_barrier_id = \ + cbw_v_scratch.tmaReserve(elect_one_, (TILE_SIZE_V) * Kernel_traits::ELEMENT_BYTES); \ + v_barrier_ptr = cbw_v_scratch.barrier_ptr(v_barrier_id); \ + v_smem = shared->smem_v_scratch.data(); \ + } else { \ + v_barrier_id = cbw_v.tmaReserve(elect_one_, (TILE_SIZE_V) * Kernel_traits::ELEMENT_BYTES); \ + v_barrier_ptr = cbw_v.barrier_ptr(v_barrier_id); \ + v_smem = shared->smem_v.data(); \ + } \ + \ + named_barrier_wait(SYNC_BARRIER, NUM_THREADS_IN_DMA_GROUP); + + // Load k,v tiles from gmem to smem by TMA. + template + inline __device__ int load_kv(int bidh_kv, int kv_tile_start_offset, cudaTmaDesc const* desc_k, + cudaTmaDesc const* desc_v, Shared* shared, BufferWriter& cbw_k, + BufferWriter& cbw_v, BufferWriterScratch& cbw_v_scratch) { + PREPARE_KV_BUFFER() + + // split D into multiple groups in order to satisfy the TMA 128B sizzle mode +#pragma unroll + for (int di = 0; di < Kernel_traits::D_GROUPS; ++di) { + const int32_t k_coords[3] = {di * Kernel_traits::D_PER_GROUP, bidh_kv, + sum_s_kv_ + kv_tile_start_offset}; + + fmha::utmaldg<3, fmha::cudaTmaDescType::TILED, false>( + desc_k, + __cvta_generic_to_shared( + &shared->smem_k[k_barrier_id * TILE_SIZE_K + di * TILE_SIZE_K_PER_D_GROUP]), + __cvta_generic_to_shared(cbw_k.barrier_ptr(k_barrier_id)), k_coords, elect_one_); + } + +#pragma unroll + for (int di = 0; di < Kernel_traits::DV_GROUPS; ++di) { + const int32_t v_coords[3] = {di * Kernel_traits::D_PER_GROUP, bidh_kv, + sum_s_kv_ + kv_tile_start_offset}; + + fmha::utmaldg<3, fmha::cudaTmaDescType::TILED, false>( + desc_v, + __cvta_generic_to_shared( + &v_smem[v_barrier_id * TILE_SIZE_V + di * TILE_SIZE_V_PER_D_GROUP]), + __cvta_generic_to_shared(v_barrier_ptr), v_coords, elect_one_); + } + + return v_barrier_id; + } + + // Load paged k,v tiles from gmem to smem by TMA. + template + inline __device__ int load_paged_kv(int bidh_kv, int kv_tile_start_offset, + int num_valid_kv_blocks, int tokens_per_block_log2, + int blocks_per_tma_load, int blocks_per_tma_load_log2, + int max_blocks_per_sequence, + int32_t const* paged_block_offsets, + cudaTmaDesc const* desc_k, cudaTmaDesc const* desc_v, + Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v, + BufferWriterScratch& cbw_v_scratch) { + PREPARE_KV_BUFFER() + + // Paged KV cache block idx. 
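// ----------------------------------------------------------------------------------------------
// Illustrative sketch (not part of the file above): the paged-KV index arithmetic computed just
// below, mapping the first token of a KV tile to a cache block plus an in-block offset, and
// clamping to the last block that is actually backed by data, as the TMA loop does. Values are
// arbitrary.
// ----------------------------------------------------------------------------------------------
#include <algorithm>
#include <cstdio>

int main() {
  const int tokens_per_block_log2 = 6;   // 64 tokens per paged KV block
  const int kv_tile_start_offset = 384;  // first token covered by this STEP_KV tile
  const int num_valid_kv_blocks = 5;     // blocks that hold real tokens for this sequence

  const int block_idx = kv_tile_start_offset >> tokens_per_block_log2;
  const int offset_in_block = kv_tile_start_offset & ((1 << tokens_per_block_log2) - 1);
  const int bounded_idx = std::min(num_valid_kv_blocks - 1, block_idx);

  std::printf("block=%d offset=%d bounded=%d\n", block_idx, offset_in_block, bounded_idx);
  return 0;
}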
+ int paged_kv_block_idx = kv_tile_start_offset >> tokens_per_block_log2; + int kv_offset_in_block = kv_tile_start_offset & ((1 << tokens_per_block_log2) - 1); + + // coordinates: d, s, h, 1 + int const tile_size_k_per_block = TILE_SIZE_K_PER_D_GROUP >> blocks_per_tma_load_log2; + static_assert(TILE_SIZE_V_PER_D_GROUP == TILE_SIZE_K_PER_D_GROUP, + "KV tile should have the same tensor size."); + for (int bi = 0; bi < blocks_per_tma_load; ++bi) { + int const bounded_block_idx = min(num_valid_kv_blocks - 1, paged_kv_block_idx + bi); + + const int32_t k_paged_block_offset = paged_block_offsets[bounded_block_idx]; + const int32_t v_paged_block_offset = + paged_block_offsets[max_blocks_per_sequence + bounded_block_idx]; + +#pragma unroll + for (int di = 0; di < Kernel_traits::D_GROUPS; ++di) { + const int32_t k_coords[4] = {di * Kernel_traits::D_PER_GROUP, kv_offset_in_block, bidh_kv, + k_paged_block_offset}; + + fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>( + desc_k, + __cvta_generic_to_shared( + &shared->smem_k[k_barrier_id * TILE_SIZE_K + di * TILE_SIZE_K_PER_D_GROUP + + bi * tile_size_k_per_block]), + __cvta_generic_to_shared(cbw_k.barrier_ptr(k_barrier_id)), k_coords, elect_one_); + } + +#pragma unroll + for (int di = 0; di < Kernel_traits::DV_GROUPS; ++di) { + const int32_t v_coords[4] = {di * Kernel_traits::D_PER_GROUP, kv_offset_in_block, bidh_kv, + v_paged_block_offset}; + + fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>( + desc_v, + __cvta_generic_to_shared( + &v_smem[v_barrier_id * TILE_SIZE_V + di * TILE_SIZE_V_PER_D_GROUP + + bi * tile_size_k_per_block]), + __cvta_generic_to_shared(v_barrier_ptr), v_coords, elect_one_); + } + } + + return v_barrier_id; + } + + template + // Transpose v tile explicitly as QGMMA doesn't support it. + inline __device__ void transpose_v_tile(int v_scratch_barrier_id, Shared* shared, + BufferWriter& cbw_v, + BufferReaderScratch& cbr_v_scratch) { + static_assert(NUM_THREADS_IN_DMA_GROUP == 128, ""); + Transposer transposer(threadIdx.x % NUM_THREADS_IN_DMA_GROUP); + + // Src buffer available + int ready = cbr_v_scratch.peek(); + if (!ready) { + cbr_v_scratch.wait(); + } + uint32_t smem_v_src = __cvta_generic_to_shared(&shared->smem_v_scratch[v_scratch_barrier_id]); + + // Dst buffer available + int v_barrier_id = cbw_v.threadReserve(); + uint32_t smem_v_dst = __cvta_generic_to_shared(&shared->smem_v[v_barrier_id * TILE_SIZE_V]); + +// Explicitly transpose the v buffer in smem for fp8. + +// The transposer currently has support of the following tile sizes: +// - D=32, S (or KV_STEP)=128 +// - D=64, S (or KV_STEP)=64, 128 +// - D=128, S (or KV_STEP)=64, 128 +// In addition, the transposer can only work with contiguous chunk of SMEM. +// +// For example, if V tile size is D=256 S=256, we can divide the TMA load of the V tile +// (SxD) into 2x2 chunks of size 128x128. This way, when tiles (0, 0), (0, 1) are transposed, +// either the load and the store of the data can be performed in a contiguous memory. +// +// Keep in mind in order to match GMMA requirement, we need to store the transposed tiles +// along D dim first then S dim. Leading dimension S after the transpose is at most 128B. 
+// +// Logical: +// D - D I M (contiguous) +// +// 128 128 S +// <------------> <------------> - +// s, d = (0, 0) | s, d = (0, 1) D +// ------------------------------ I +// s, d = (1, 0) | s, d = (1, 1) M +// +// In SMEM: +// D - D I M +// +// 128 128 128 128 S +// <------------> <-------------> <-------------> <------------> - +// s, d = (0, 0) | s, d = (0, 1) | s, d = (1, 0) | s, d = (1, 1) D (contiguous) +// I +// M +// +#pragma unroll + for (int kgroup_idx = 0; kgroup_idx < Kernel_traits::BMM2_K_GROUPS; kgroup_idx++) { +#pragma unroll + for (int dgroup_idx = 0; dgroup_idx < Kernel_traits::DV_GROUPS; dgroup_idx++) { + // Src smem block is k first then d + uint32_t src_offset = + (kgroup_idx * Kernel_traits::BMM2_K_PER_GROUP * Kernel_traits::D_PER_GROUP + + dgroup_idx * Kernel_traits::D_PER_GROUP * Kernel_traits::STEP_KV) * + Kernel_traits::ELEMENT_BYTES; + + // Dst smem block is d first then k + uint32_t dst_offset = + (dgroup_idx * Kernel_traits::BMM2_K_PER_GROUP * Kernel_traits::D_PER_GROUP + + kgroup_idx * Kernel_traits::BMM2_K_PER_GROUP * Kernel_traits::DV) * + Kernel_traits::ELEMENT_BYTES; + + transposer.template transpose_(smem_v_src + src_offset, smem_v_dst + dst_offset); + } + } + + fence_view_async_shared(); // Commit STSM + named_barrier_wait(SYNC_BARRIER, NUM_THREADS_IN_DMA_GROUP); // Sync before signaling + cbw_v.threadCommit(elect_one_, v_barrier_id); // Signal readiness + cbr_v_scratch.pop(elect_one_); // Advance to next phase + } + + inline __device__ void get_next_tile_id(int local_wid, int tiw, uint32_t smem_tile_id, + uint32_t* tile_id_counter_ptr) { + if constexpr (DMA_GROUP_TRANSPOSE_V) { + if (elect_one_) { + tile_id_ = atomicAdd(tile_id_counter_ptr, 1); + sts(smem_tile_id, tile_id_); + } + fence_view_async_shared(); + named_barrier_wait(SYNC_BARRIER, 128); + if (tiw == 0) { + lds(tile_id_, smem_tile_id); + } + tile_id_ = __shfl_sync(0xffffffff, tile_id_, 0); + // only one warp involved when the dma group doesn't need to transpose the v tile. + } else { + if (elect_one_) { + tile_id_ = atomicAdd(tile_id_counter_ptr, 1); + } + tile_id_ = __shfl_sync(0xffffffff, tile_id_, 0); + } + } + }; + + //////////////////////////////////////////////////////////////////////////////////////////////// + + struct Host { + Host() {} + + // Set TMA descriptors on host, and launch as __grid_constant__. + // Paged KV FMHA parameters. + void init_params(bert::Fused_multihead_attention_params_v2& params, + bert::Fused_multihead_attention_launch_params const& launch_params, + cudaStream_t stream) const { + const uint32_t d = params.d; + const uint32_t dv = params.dv; + const uint32_t h = params.h; + const uint32_t h_kv = params.h_kv; + + // Total sequence length. + const uint32_t total_seqlen = + params.is_s_padded ? (params.b * params.s) : launch_params.total_q_seqlen; + + // O Layout: [total_seqlen, H, DV] + // Per batch tensor size. + uint32_t tensor_size_o[3] = {dv, h, total_seqlen}; + + // Stride size in bytes. Assumes least significant dim is 1 + uint64_t tensor_stride_o[2] = {dv * Kernel_traits::ELEMENT_BYTES, + uint64_t(params.o_stride_in_bytes)}; + + // Starting memory address + char* o_ptr = reinterpret_cast(params.o_ptr); + + // Box size of TMA + uint32_t box_size_o[3] = {Kernel_traits::D_PER_GROUP, 1, 16}; + + // Traversal stride. + uint32_t traversal_stride[3] = {1, 1, 1}; + + // OOB fill zeros. + uint32_t oob_fill = 0; + + // FP32 to TF32 conversion disabled. + uint32_t fp32_to_tf32 = 0; + + // GMMA descriptor mode. 
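// ----------------------------------------------------------------------------------------------
// Illustrative sketch (not part of the file above): the single-warp flavour of get_next_tile_id,
// written as a standalone device helper. One elected lane (assumed to be lane 0 here) reserves
// the next unit of work with an atomic counter, and the reserved id is broadcast to the rest of
// the warp. The shared-memory round trip used when the DMA group spans 128 threads is omitted.
// ----------------------------------------------------------------------------------------------
__device__ unsigned int next_tile_id(unsigned int* tile_id_counter, bool elect_one) {
  unsigned int tile_id = 0u;
  if (elect_one) {
    tile_id = atomicAdd(tile_id_counter, 1u);  // grab the next tile index exactly once per warp
  }
  return __shfl_sync(0xffffffffu, tile_id, 0);  // broadcast from lane 0 to the whole warp
}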
+      static constexpr int D_BYTES_PER_GROUP = Kernel_traits::D_BYTES_PER_GROUP;
+      static constexpr fmha::cudaTmaDescSwizzle swizzle_mode =
+          (D_BYTES_PER_GROUP > 64   ? fmha::cudaTmaDescSwizzle::SWIZZLE_128B
+           : D_BYTES_PER_GROUP > 32 ? fmha::cudaTmaDescSwizzle::SWIZZLE_64B
+                                    : fmha::cudaTmaDescSwizzle::SWIZZLE_32B);
+
+      static_assert(STEP_KV <= 256 && STEP_Q <= 256, "max box size is 256");
+
+      // Desc Format (data type).
+      static constexpr fmha::cudaTmaDescFormat desc_format = (Kernel_traits::ELEMENT_BYTES == 1)
+                                                                 ? fmha::cudaTmaDescFormat::U8
+                                                                 : fmha::cudaTmaDescFormat::F16_RN;
+
+      fmha::Multiple_tma_descriptor<3> qo_tma_descriptor;
+
+      // TMA O
+      if (Kernel_traits::USE_TMA_STORE) {
+        qo_tma_descriptor.set_tma_desctriptor(
+            o_ptr, desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode,
+            fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_o, tensor_stride_o,
+            traversal_stride, box_size_o, oob_fill, fp32_to_tf32, &params.tma_desc_o);
+      }
+
+      auto const layout = launch_params.attention_input_layout;
+
+      // Q always uses a 3D tensor.
+      uint32_t tensor_size_q[3] = {d, h, total_seqlen};
+
+      uint64_t tensor_stride_q[2] = {d * Kernel_traits::ELEMENT_BYTES,
+                                     uint64_t(params.q_stride_in_bytes)};
+
+      char* q_ptr = reinterpret_cast<char*>(
+          layout == fmha::Attention_input_layout::PACKED_QKV ? params.qkv_ptr : params.q_ptr);
+
+      uint32_t box_size_q[3] = {Kernel_traits::D_PER_GROUP, 1, STEP_Q};
+
+      if (layout == fmha::Attention_input_layout::Q_PAGED_KV) {
+        // KV in q_paged_kv uses a 4D tensor.
+        // Layout: [INT32_MAX, H_KV, TokensPerBlock, D]
+        const uint32_t tokens_per_block = params.paged_kv_cache.mTokensPerBlock;
+        uint32_t tensor_size_k[4] = {d, tokens_per_block, h_kv, INT_MAX};
+        uint32_t tensor_size_v[4] = {dv, tokens_per_block, h_kv, INT_MAX};
+
+        uint64_t tensor_stride_k[3];
+        tensor_stride_k[0] = params.k_stride_in_bytes / tokens_per_block;  // d
+        tensor_stride_k[1] = params.k_stride_in_bytes;                     // d * 64
+        tensor_stride_k[2] = params.paged_kv_cache.mBytesPerBlock;
+        uint64_t tensor_stride_v[3];
+        // We cannot use dv * Kernel_traits::ELEMENT_BYTES because V may be padded (MLA).
+        tensor_stride_v[0] = params.v_stride_in_bytes / tokens_per_block;  // dv
+        tensor_stride_v[1] = params.v_stride_in_bytes;                     // dv * 64
+        tensor_stride_v[2] = params.paged_kv_cache.mBytesPerBlock;
+
+        char* kv_ptr = reinterpret_cast<char*>(params.paged_kv_cache.mPoolPtr);
+
+        uint32_t box_size_kv[4] = {Kernel_traits::D_PER_GROUP,
+                                   std::min(tokens_per_block, STEP_KV), 1, 1};
+
+        assert(STEP_KV % tokens_per_block == 0 || tokens_per_block % STEP_KV == 0);
+        params.blocks_per_tma_load = std::max(1, STEP_KV / tokens_per_block);
+        params.blocks_per_tma_load_log2 = log2(params.blocks_per_tma_load);
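+
+        // Illustrative cases (values assumed for the example): tokens_per_block = 64 with
+        // STEP_KV = 128 gives blocks_per_tma_load = 2, i.e. two paged blocks are read per KV
+        // tile; tokens_per_block = 128 with STEP_KV = 64 gives blocks_per_tma_load = 1 and the
+        // box of std::min(tokens_per_block, STEP_KV) = 64 tokens reads only part of one block
+        // per step. The assert above guarantees one of the two divisibility cases holds.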
+
+        uint32_t traversal_stride[4] = {1, 1, 1, 1};
+
+        fmha::Multiple_tma_descriptor<4> kv_tma_descriptor;
+        // K
+        kv_tma_descriptor.set_tma_desctriptor(
+            kv_ptr, desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode,
+            fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_k, tensor_stride_k,
+            traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, &params.tma_desc_k);
+        // V
+        kv_tma_descriptor.set_tma_desctriptor(
+            kv_ptr, desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode,
+            fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_v, tensor_stride_v,
+            traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, &params.tma_desc_v);
+      } else {
+        // Otherwise KV uses a 3D tensor.
+        uint32_t tensor_size_k[3] = {d, h_kv, total_seqlen};
+        uint32_t tensor_size_v[3] = {dv, h_kv, total_seqlen};
+
+        uint64_t tensor_stride_k[2] = {d * Kernel_traits::ELEMENT_BYTES,
+                                       uint64_t(params.k_stride_in_bytes)};
+        uint64_t tensor_stride_v[2] = {dv * Kernel_traits::ELEMENT_BYTES,
+                                       uint64_t(params.v_stride_in_bytes)};
+
+        uint32_t box_size_kv[3] = {Kernel_traits::D_PER_GROUP, 1, STEP_KV};
+
+        char *k_ptr, *v_ptr;
+
+        if (layout == fmha::Attention_input_layout::PACKED_QKV) {
+          if (!HEADS_INTERLEAVED || h != h_kv) {
+            // Layout: [total_seqlen, (H, D) + (H_KV, D) + (H_KV, DV)]
+            // All of MHA in TRTLLM is in this layout,
+            // and MQA/GQA must use this layout.
+            k_ptr = q_ptr + h * d * Kernel_traits::ELEMENT_BYTES;
+            v_ptr = k_ptr + h_kv * d * Kernel_traits::ELEMENT_BYTES;
+          } else {
+            // Layout: [total_seqlen, H, D + D + DV]
+            // Currently only used in MHA in fmha_v2 tests.
+            tensor_stride_q[0] = tensor_stride_k[0] = tensor_stride_v[0] =
+                (2 * d + dv) * Kernel_traits::ELEMENT_BYTES;
+            k_ptr = q_ptr + d * Kernel_traits::ELEMENT_BYTES;
+            v_ptr = k_ptr + d * Kernel_traits::ELEMENT_BYTES;
+          }
+        } else if (layout == fmha::Attention_input_layout::CONTIGUOUS_Q_KV) {
+          k_ptr = reinterpret_cast<char*>(params.kv_ptr);
+          v_ptr = k_ptr + h_kv * d * Kernel_traits::ELEMENT_BYTES;
+        } else if (layout == fmha::Attention_input_layout::SEPARATE_Q_K_V) {
+          k_ptr = reinterpret_cast<char*>(params.k_ptr);
+          v_ptr = reinterpret_cast<char*>(params.v_ptr);
+        }
+
+        fmha::Multiple_tma_descriptor<3> kv_tma_descriptor;
+        // K
+        kv_tma_descriptor.set_tma_desctriptor(
+            k_ptr, desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode,
+            fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_k, tensor_stride_k,
+            traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, &params.tma_desc_k);
+        // V
+        kv_tma_descriptor.set_tma_desctriptor(
+            v_ptr, desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode,
+            fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_v, tensor_stride_v,
+            traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, &params.tma_desc_v);
+      }
+      // Q
+      qo_tma_descriptor.set_tma_desctriptor(
+          q_ptr, desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode,
+          fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_q, tensor_stride_q,
+          traversal_stride, box_size_q, oob_fill, fp32_to_tf32, &params.tma_desc_q);
+    }
+  };
+};
+
+}  // namespace ws
+
+}  // namespace fmha
diff --git a/csrc/fmha_v2/fmha/warpspec/epilogue.h b/csrc/fmha_v2/fmha/warpspec/epilogue.h
new file mode 100644
index 0000000000..15f8636207
--- /dev/null
+++ b/csrc/fmha_v2/fmha/warpspec/epilogue.h
@@ -0,0 +1,1091 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights
+ * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace fmha {
+namespace ws {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Special Softmax struct to handle optimization tricks on Hopper Warp-Specialized Kernels.
+template