CK MXFP8 Group Gemm gfx1250 Enablement#613
Conversation
…rash; remaining issue is numerical validation vs BF16 sequential reference.
…roup-gemm-gfx1250-clean
| if (arch == 94) { | ||
| return GPUArch::GFX942; | ||
| } | ||
| if (arch == 95) { | ||
| return GPUArch::GFX950; | ||
| } | ||
| if (arch == 1250) { | ||
| return GPUArch::GFX1250; | ||
| } |
There was a problem hiding this comment.
Could this be a switch?
There was a problem hiding this comment.
Yeah that looks nicer, thanks. Done in f3ecda3
| if (arch == 95) { | ||
| return GPUArch::GFX950; | ||
| } | ||
| if (arch == 1250) { |
There was a problem hiding this comment.
Yeah I think you're right, thanks. Fixed in f3ecda3
| std::vector<mx_grouped_gemm_kargs> descs; | ||
| descs.reserve(group_num); | ||
|
|
||
| std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_scale_shuffled_bufs; |
There was a problem hiding this comment.
Does ck_tile::DeviceMem allocate new memory? Can we use a workspace here?
There was a problem hiding this comment.
Yes, we can use workspace here. Done in 94b0126
| }; | ||
|
|
||
| template <typename ScaleType, ck_tile::index_t ScaleBlockSize, bool KStride> | ||
| __global__ void preshuffle_scale_gfx1250_kernel(const ScaleType* __restrict__ src, |
There was a problem hiding this comment.
Is this the same shuffling as in #605 ? Maybe we can add a comment here.
There was a problem hiding this comment.
There was a problem hiding this comment.
I think the comment you added goes in the right direction, I would additionally mention what you said here, that this is different from the other mxfp8 gemm swizzling, and that it is expected by CK 1250 WMMA kernel.
alextmagro
left a comment
There was a problem hiding this comment.
Sorry, my review was left as pending, so some of my comments may have already been addressed. Thanks!
| } | ||
|
|
||
| template <typename T> | ||
| static void fill_randn_cpu(Tensor* t, float scale, int seed) { |
There was a problem hiding this comment.
Why not use our hipRAND generator in test_common?
| return cases; | ||
| } | ||
|
|
||
| static const std::vector<CaseConfig> kCases = make_cases(); |
There was a problem hiding this comment.
I think we should probably use seeds generated from test names like the rest of the c++ tests
There was a problem hiding this comment.
Yeah, it should now be consistent in 5b4b7fe
| #pragma once | ||
|
|
||
| #include <hip/hip_runtime.h> | ||
| #include "common/util/cuda_runtime.h" |
There was a problem hiding this comment.
nit: this belongs after common headers
| #include "ck_tile/core.hpp" | ||
| #include "ck_tile/ops/epilogue.hpp" | ||
| #include "ck_tile/ops/gemm.hpp" | ||
| #include "ck_tile/host/kernel_launch.hpp" |
There was a problem hiding this comment.
nit: /host/ goes before /ops/, and /elementwise/ goes before /gemm/
| NVTE_ERROR("ck_tile_mx_grouped_gemm: expected effective A/B scale_inv tensors to be rank-2."); | ||
| } | ||
|
|
||
| const int64_t M = ctx.transA ? Ad1 : Ad0; |
There was a problem hiding this comment.
I think these should be size_ts, unless negative values are needed.
There was a problem hiding this comment.
Yeah that's fair. I changed that in bdc6b4e
| KScale, | ||
| stream); | ||
| } | ||
| descs.emplace_back(mx_grouped_gemm_kargs( |
There was a problem hiding this comment.
Another stylistic comment, but there are lots of line breaks for functions with 1 parameter per line. I personally prefer a more compact style with only line breaks as needed, especially when variable names are relatively short
There was a problem hiding this comment.
Made some additional stylistic changes in bdc6b4e
| ok = invoke_mx_grouped_gemm<GroupedGemKernelParam_Wmma, | ||
| AType, BType, CType, | ||
| AScaleType, BScaleType>(descs,ctx,s); | ||
| }); |
There was a problem hiding this comment.
We need // NOLINT(*) at the end of every TRANSFORMER_ENGINE_TYPE_SWITCH_* statement
| * License for AMD contributions = MIT. See LICENSE for more information | ||
| ************************************************************************/ | ||
|
|
||
| bool ck_tile_mx_grouped_gemm(const NVTETensor* A, |
There was a problem hiding this comment.
Missing #pragma once, and maybe name file .h instead of .hpp for consistency?
There was a problem hiding this comment.
On second thought, can we just add this to ck_grouped_gemm_common.h?
There was a problem hiding this comment.
Good catch. Since ck_grouped_gemm.h was meant to be the public API and ck_grouped_gemm_common.h is internal, I moved the declaration to ck_grouped_gemm.h, removing the need for ck_mx_grouped_gemm.h. Changes made in bdc6b4e
| } | ||
| cublas_path(); | ||
| auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); | ||
| const bool mxfp8_gemm = transformer_engine::is_mxfp8_scaling(inputA->scaling_mode); |
There was a problem hiding this comment.
Can probably inline this into the if statement since it is only used once
|
|
||
| static constexpr ck_tile::index_t ScaleBlockSize = 32; | ||
|
|
||
| enum struct MxGemmPipelineType |
There was a problem hiding this comment.
I do prefer K&R style, and we lean towards that in the codebase. Consider moving open brackets to same line throughout, and maybe using post-increments and attaching references/pointers to the var instead of the type.
There was a problem hiding this comment.
Thanks for pointing that out. Made the edits in 2e74a63
…ing utilities from test_common.cu
…over existing implementation
| }; | ||
|
|
||
| static inline GPUArch detect_gpu_arch() { | ||
| switch (cuda::sm_arch(0)) { |
There was a problem hiding this comment.
I think you can just use
| switch (cuda::sm_arch(0)) { | |
| switch (cuda::sm_arch()) { |
Description
Integrates CK MXFP8 Group GEMM pipeline into TE.
Fixes https://github.com/ROCm/frameworks-internal/issues/16039
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: