GEMM reference computation offload#392
Conversation
44df11e to
e60b912
Compare
e60b912 to
557d580
Compare
557d580 to
ad748da
Compare
There was a problem hiding this comment.
Pull request overview
This PR introduces a GPU-accelerated implementation of the GEMM reference computation using HIP/CUDA kernels to improve performance over the previous CPU-based implementation. The reference computation is critical for validating GEMM operations, and offloading it to the GPU significantly speeds up testing.
Key Changes
- Replaced CPU OpenMP-based reference GEMM computation with GPU kernel implementation
- Introduced
compute_ref_kernelto perform matrix multiplication, bias addition, GELU activation, and scaling on GPU - Refactored both tensor-wise and MXFP8 code paths to use a common
compute_ref_implfunction that manages device memory and kernel execution
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
alextmagro
left a comment
There was a problem hiding this comment.
Hi Matthias! Looks great, just have a couple performance questions
dbf7ae9 to
11e090b
Compare
| fp8e8m0* dB_scale = nullptr; | ||
|
|
||
| // Allocations and H2D transfers | ||
| NVTE_CHECK_CUDA(cudaMalloc(&dA, lenA * sizeof(A_Type))); |
There was a problem hiding this comment.
We can adapt existing test tensor classes (
TransformerEngine/tests/cpp/test_common.cu
Line 226 in 669b556
In fact, we can change the api of reference computing by taking directly const tensor& therefore we don't need to re-allocate the input and do one extra copy
There was a problem hiding this comment.
What do you think of 3ecea7f? This also merges the mxfp8/non-mxfp8 paths.
There was a problem hiding this comment.
Thanks for consolidating with existing apis in test_common.cu.
In fact, I still see some cudaMalloc and cudaFree, which can be replaced by using existing test tensor class apis.
For example, the device pointer for scale (
TransformerEngine/tests/cpp/test_common.cu
Lines 321 to 335 in 2bc74c8
There was a problem hiding this comment.
I replaced the remaining raw allocations in the reference path with test::Tensor for the temporary device buffers (RefD/RefGelu/RefAmax) in e11e400.
There was a problem hiding this comment.
I see. Yeah, it indeed saved some cudaMalloc/cudaFrees.
How about we put the RefD instantiation inside PerformTest, and pass the Tensor RefD (including its RefAmax D) and RefPreGeluOut to run_reference directly (instead of std::unique_ptr<D_Type[]>& ref_D, float* ref_amax_d, std::unique_ptr<Gelu_Type[]>& ref_pre_gelu_out). Then this can save some ref cpu ptr allocation.
This reverts commit 86fbbac.
Description
Introduce a HIP implementation of the GEMM reference computation to speed up these computations.
Partly addresses https://github.com/ROCm/frameworks-internal/issues/14746
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: