feat: add MXFP8 fused operators for Wan transformer inference on SM120 #1090
Fatemanx wants to merge 1 commit into
Code Review
This pull request implements MXFP8 fused operations for the Wan transformer model, specifically optimized for SM120/SM120a GPUs. The changes include new CUDA kernels for MXFP8 GeLU quantization, modulate quantization, and a fused residual-gate GEMM utilizing CUTLASS, along with corresponding Python wrappers and unit tests. Reviewer feedback highlights several optimization and code quality improvements: moving static parameter device transfers out of the inference loop, consolidating duplicated hardware validation logic into a common utility, replacing std::cerr with idiomatic TORCH_CHECK calls, improving numerical precision by avoiding intermediate rounding in the residual update, and eliminating dynamic tensor allocations in the performance-critical path.
        return self._mxfp8_apply_quantized(module, input_tensor_quant, input_tensor_scale)

    def _mxfp8_apply_quantized(self, module, input_tensor_quant, input_tensor_scale):
        module.alpha = module.alpha.to(module.weight.device)
Moving module.alpha to the weight device in every iteration of the inference loop introduces unnecessary Python overhead and potential synchronization points. Since alpha is a static quantization parameter, it should ideally be moved to the correct device once during model initialization. At the very least, check if the device move is necessary before performing it to avoid redundant operations.
    if module.alpha.device != module.weight.device:
        module.alpha = module.alpha.to(module.weight.device)

    inline void check_sm120_or_throw(torch::Tensor const& tensor, char const* op_name) {
      int device = tensor.get_device();
      check_valid_cuda_device_index(device, op_name);

      static std::array<std::once_flag, kMaxCudaDevices> device_once;
      static std::array<int, kMaxCudaDevices> cached_major{};
      static std::array<int, kMaxCudaDevices> cached_minor{};

      std::call_once(device_once[device], [device]() {
        CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&cached_major[device], cudaDevAttrComputeCapabilityMajor, device));
        CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&cached_minor[device], cudaDevAttrComputeCapabilityMinor, device));
      });

      TORCH_CHECK(
          cached_major[device] == 12,
          op_name,
          " is only supported on SM120/SM120a GPUs, got CUDA device ",
          device,
          " with compute capability ",
          cached_major[device],
          ".",
          cached_minor[device]);
    }
The check_sm120_or_throw utility is duplicated in both mxfp8_quant_kernels_sm120.cu and mxfp8_scaled_mm_kernels_sm120.cu. Furthermore, the implementations differ (one uses cudaDeviceGetAttribute while the other uses cudaGetDeviceProperties). This function should be moved to a common header (e.g., utils.h) to ensure consistency and reduce code duplication.
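The consolidation suggested here can be illustrated on the Python side as a single cached probe that every wrapper shares. This is only a sketch of the idea, not the C++ header from the PR: `make_sm120_checker` and `query_capability` are hypothetical names, and the probe is injected so the example runs without a GPU. In a real wrapper, `torch.cuda.get_device_capability` could plausibly serve as the probe.

```python
from functools import lru_cache
from typing import Callable, Tuple


def make_sm120_checker(query_capability: Callable[[int], Tuple[int, int]]):
    """Build one shared checker so every operator uses the same probe logic.

    `query_capability` is a hypothetical injected function that returns
    (major, minor) for a device index.
    """

    @lru_cache(maxsize=None)
    def capability(device: int) -> Tuple[int, int]:
        # Cached per device index: repeated checks reuse the first result.
        return query_capability(device)

    def check_sm120_or_throw(device: int, op_name: str) -> None:
        major, minor = capability(device)
        if major != 12:
            raise RuntimeError(
                f"{op_name} is only supported on SM120/SM120a GPUs, got CUDA device "
                f"{device} with compute capability {major}.{minor}"
            )

    return check_sm120_or_throw


if __name__ == "__main__":
    # Fake probe for illustration: device 0 is SM120, anything else is SM90.
    check = make_sm120_checker(lambda dev: (12, 0) if dev == 0 else (9, 0))
    check(0, "scaled_mxfp8_gelu_quant")  # passes silently
```

Unlike `std::call_once`, `lru_cache` does not guarantee a single probe under concurrent first calls, so this sketch matches the C++ behaviour only in single-threaded use.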
    default: {
      std::cerr << "Observing: " << input.scalar_type() << " for the input datatype which is invalid";
      throw std::runtime_error("Unsupported input data type for gelu_quantize_to_fp8.");
    }
Using std::cerr for error reporting in a library is not recommended as it bypasses standard logging mechanisms and can be missed in production environments. Using TORCH_CHECK is more idiomatic in PyTorch extensions and provides a better error message including the tensor type.
    default: {
      TORCH_CHECK(false, "Unsupported input data type for gelu_quantize_to_fp8: ", input.scalar_type());
    }
    float product = __bfloat162float(ffn_out[idx]) * __bfloat162float(gate[gate_idx]);
    __nv_bfloat16 rounded_product = __float2bfloat16(product);
    float sum = __bfloat162float(residual[idx]) + __bfloat162float(rounded_product);
    residual[idx] = __float2bfloat16(sum);
The intermediate rounding to __nv_bfloat16 at line 661 is unnecessary and reduces numerical precision. You can perform the addition in float before the final rounding to the output type to maintain higher accuracy during the residual update.
    float product = __bfloat162float(ffn_out[idx]) * __bfloat162float(gate[gate_idx]);
    float sum = __bfloat162float(residual[idx]) + product;
    residual[idx] = __float2bfloat16(sum);
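To make the precision argument concrete, the sketch below emulates both variants on the host with plain bit manipulation. It assumes round-to-nearest-even for the float-to-BF16 conversion and ignores NaN and infinity; the helper names (`to_f32`, `to_bf16`, `update_double_rounded`, `update_single_rounded`) are illustrative and do not appear in the PR. The inputs are chosen by hand so that the extra rounding changes the stored result.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float (double) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 value, ties to even (finite inputs only)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                  # round-to-nearest-even bias
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def update_double_rounded(residual: float, ffn_out: float, gate: float) -> float:
    """Original kernel: the product is rounded to BF16 before the add."""
    product = to_f32(ffn_out * gate)
    rounded_product = to_bf16(product)
    return to_bf16(to_f32(residual + rounded_product))


def update_single_rounded(residual: float, ffn_out: float, gate: float) -> float:
    """Suggested kernel: the add stays in float32, with one final rounding."""
    product = to_f32(ffn_out * gate)
    return to_bf16(to_f32(residual + product))


# All three inputs are exactly representable in BF16.
residual = 1.0
ffn_out = 1.0 + 2.0**-7
gate = 2.0**-8 - 2.0**-16

if __name__ == "__main__":
    print(update_double_rounded(residual, ffn_out, gate))  # 1.0
    print(update_single_rounded(residual, ffn_out, gate))  # 1.0078125
```

For these inputs the exact sum is about 1.00392, which lies just above the midpoint between the neighbouring BF16 values 1.0 and 1.0078125. Rounding the product first pushes the sum onto that midpoint, so it resolves to 1.0; the single-rounding path returns the nearer value.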
        residual, A, B, A_sf, B_sf, alpha, bias, gate, meta.m, meta.n, meta.k, stream);
        return;
      }
      auto ffn_out = torch::empty_like(residual);
Allocating a new tensor using torch::empty_like(residual) in the fallback path for 2D gates will significantly degrade performance during prefill (where M is large). Since this PR aims to optimize the FFN block, you should avoid dynamic allocations in the hot path. Consider using a pre-allocated workspace or allowing the caller to provide an output buffer.
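One way to picture the pre-allocated workspace idea is a grow-only scratch buffer that is handed out on every call and reallocated only when a larger size is requested. The sketch below uses a host-side `bytearray` as a stand-in for a device tensor, and the `Workspace` class is hypothetical; it is not part of the PR or of PyTorch.

```python
class Workspace:
    """Grow-only scratch buffer reused across calls (hypothetical helper)."""

    def __init__(self) -> None:
        self._buf = bytearray()
        self.allocations = 0  # counts how often a new backing buffer was created

    def get(self, nbytes: int) -> memoryview:
        """Return a view of at least `nbytes`, allocating only when growing."""
        if nbytes > len(self._buf):
            self._buf = bytearray(nbytes)
            self.allocations += 1
        return memoryview(self._buf)[:nbytes]


if __name__ == "__main__":
    ws = Workspace()
    for _ in range(3):
        scratch = ws.get(4096)  # same backing buffer after the first call
    print(ws.allocations)
```

The alternative mentioned in the comment, letting the caller pass an output buffer, moves the same decision one level up: the caller owns the buffer's lifetime and the operator never allocates.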
Implement three fused CUDA kernels for MXFP8 quantized inference on Blackwell (SM120):

1. scaled_mxfp8_gelu_quant: fuse GELU activation + E8M0 quantization
2. scaled_mxfp8_modulate_quant: fuse scale/shift modulation + quantization
3. cutlass_scaled_mxfp8_mm_residual_gate: fuse GEMM + residual + gate in CUTLASS 3.x epilogue

Performance on RTX 5090 (Wan 5B FFN, m=4096, hidden=1536, ffn=8960):

- GELU+Quant: 1.30× faster (27.8μs → 21.3μs)
- Modulate+Quant: 3.26× faster (92.7μs → 28.5μs)
- GEMM+Residual+Gate: 1.40× faster (194.7μs → 138.9μs)
- End-to-end FFN: 1.20× faster (608μs → 505μs, -103μs per block)
- Reduces kernel launches from 7 to 3 per FFN block

Features:

- Supports all Wan tasks (t2v/i2v/flf2v/animate/s2v/rs2v)
- Auto-fallback on non-SM120 GPUs (H100/A100/RTX4090) with warning
- Handles FP16/BF16 activations (kernel auto-detects dtype)
- One-time device capability probe at init (eliminates ~4000 redundant checks per inference)

Tested: 10/10 unit tests pass, 6/6 fallback scenarios verified

Address review feedback (PR ModelTC#1090):

- Skip alpha device move when already on target device
- Extract check_sm120_or_throw to shared header sm120_utils.h
- Replace std::cerr with TORCH_CHECK in dtype switch fallbacks
- Avoid intermediate BF16 round in residual_gate kernel
- Apply ruff-format

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
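The speedup factors in the commit message can be recomputed from the listed before/after timings as a quick sanity check. The snippet below is only that arithmetic; the dictionary layout and function name are illustrative. Recomputing from the rounded timings lands within 0.01 of each reported factor, and the small remaining differences are presumably due to the timings having been rounded for the message (an assumption, since the raw measurements are not shown).

```python
# (before_us, after_us, reported_speedup) as listed in the commit message.
timings_us = {
    "GELU+Quant": (27.8, 21.3, 1.30),
    "Modulate+Quant": (92.7, 28.5, 3.26),
    "GEMM+Residual+Gate": (194.7, 138.9, 1.40),
    "End-to-end FFN": (608.0, 505.0, 1.20),
}


def recompute(before_us: float, after_us: float):
    """Return (speedup factor, microseconds saved) for one measurement."""
    return before_us / after_us, before_us - after_us


if __name__ == "__main__":
    for name, (before, after, reported) in timings_us.items():
        factor, saved = recompute(before, after)
        print(f"{name}: recomputed {factor:.3f}x (reported {reported:.2f}x), saves {saved:.1f} us")
```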
3b4ec5e to 94816cc
Thanks @gemini-code-assist for the thorough review! Addressed all five comments in the latest amended commit:
Verification:
Thanks for the detailed update and for addressing the review comments. The changes, including the extraction of the SM120 check to a shared header and the replacement of