diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index ce41c3310f9..a7dae9dcd82 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -44,6 +44,9 @@ add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_spl add_example_executable(example_grouped_gemm_wmma_splitk_bf16 grouped_gemm_wmma_splitk_bf16.cpp) add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_bf16) +add_example_executable(example_grouped_gemm_multiple_d_wmma_fp16 grouped_gemm_multiple_d_wmma_fp16.cpp) +add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple_d_wmma_fp16) + list(APPEND gpu_list_tf32 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_wmma_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_wmma_fp16.cpp new file mode 100644 index 00000000000..bd58ea433f8 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_wmma_fp16.cpp @@ -0,0 +1,76 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include +#include + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp" + +using ::ck::DeviceMem; +using ::ck::hip_check_error; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAdd = ck::tensor_operation::element_wise::AddAdd; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr int NumDs = 2; + +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3 + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| 
+//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<4, 4, 4>>; +// clang-format on + +#include "run_grouped_gemm_multiple_d_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp index 0e1a38b19a1..9fdcf4aaad2 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp @@ -71,339 +71,6 @@ using DeviceGemmInstance = < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4,4,4>>; // clang-format on -struct ProblemSize final -{ - std::vector Ms; - std::vector Ns; - std::vector Ks; +#include "run_grouped_gemm_multiple_d_example.inc" - std::vector stride_As; - std::vector stride_Bs; - std::vector> stride_Ds; - std::vector stride_Cs; - - ck::index_t group_count; -}; - -struct ExecutionConfig final -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; -}; - -bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) -{ - auto group_count = problem_size.group_count; - - using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument; - using GemmDesc = ck::tensor_operation::device::GemmDesc; - - // GEMM shape - std::vector gemm_descs; - std::vector ggemm_kargs; - std::vector p_Cs; - std::vector p_As; - std::vector p_Bs; - std::vector> p_Ds = {}; - - gemm_descs.reserve(group_count); - ggemm_kargs.reserve(group_count); - p_As.reserve(group_count); - p_Bs.reserve(group_count); - p_Ds.reserve(group_count); - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - std::vector> a_tensors; - std::vector> b_tensors; - std::vector, NumDs>> d_tensors; - std::vector> c_host_tensors; - std::vector> c_device_result_tensors; - - a_tensors.reserve(group_count); - b_tensors.reserve(group_count); - d_tensors.reserve(group_count); - 
c_host_tensors.reserve(group_count); - c_device_result_tensors.reserve(group_count); - - using DeviceMemPtr = std::unique_ptr; - - std::vector a_tensors_device, b_tensors_device, c_tensors_device; - std::vector> d_tensors_device; - - a_tensors_device.reserve(group_count); - b_tensors_device.reserve(group_count); - c_tensors_device.reserve(group_count); - d_tensors_device.resize(group_count); // reserve and update vector size - - std::size_t flop = 0, num_btype = 0; - - for(int i = 0; i < group_count; i++) - { - a_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); - b_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); - - auto d0_tensor = Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); - auto d1_tensor = Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); - - std::array, NumDs> d_tens = {d0_tensor, d1_tensor}; - d_tensors.push_back(d_tens); - c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); - c_device_result_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); - std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc - << " b_k_n: " << b_tensors[i].mDesc - << " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl; - - flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; - num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() + - sizeof(BDataType) * b_tensors[i].GetElementSize() + - sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs + - sizeof(EDataType) * c_device_result_tensors[i].GetElementSize(); - - switch(config.init_method) - { - case 0: break; - case 1: - a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - for(int j = 0; j < NumDs; ++j) - { - d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - } - break; - case 2: - a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - for(int j = 0; j < NumDs; ++j) - { - d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - } - break; - default: - a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); - for(int j = 0; j < NumDs; ++j) - { - d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential{}); - } - } - } - - for(int i = 0; i < group_count; i++) - { - a_tensors_device.emplace_back( - std::make_unique(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType))); - b_tensors_device.emplace_back( - std::make_unique(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType))); - c_tensors_device.emplace_back(std::make_unique( - c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType))); - - for(int j = 0; j < NumDs; ++j) - { - d_tensors_device[i].emplace_back(std::make_unique( - d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType))); - } - - a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); - b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); - for(int j = 0; j < NumDs; ++j) - { - d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data()); - } 
- c_tensors_device[i]->SetZero(); - - p_As.push_back(a_tensors_device[i]->GetDeviceBuffer()); - p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer()); - p_Ds.push_back( - {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}); - p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); - - // The device op does not have to know M problem size at lunch time. - gemm_descs.push_back({0, - problem_size.Ns[i], - problem_size.Ks[i], - problem_size.stride_As[i], - problem_size.stride_Bs[i], - problem_size.stride_Cs[i], - {problem_size.stride_Cs[i], problem_size.stride_Cs[i]}}); - ggemm_kargs.push_back( - {a_tensors_device[i]->GetDeviceBuffer(), - b_tensors_device[i]->GetDeviceBuffer(), - {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}, - c_tensors_device[i]->GetDeviceBuffer(), - problem_size.Ms[i], - problem_size.Ns[i], - problem_size.Ks[i], - problem_size.stride_As[i], - problem_size.stride_Bs[i], - {problem_size.stride_Cs[i], problem_size.stride_Cs[i]}, - problem_size.stride_Cs[i]}); - } - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - auto gemm = DeviceGemmInstance{}; - auto invoker = gemm.MakeInvoker(); - - // do GEMM - auto argument = gemm.MakeArgument( - p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); - if(!gemm.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); - hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), - ggemm_kargs.data(), - gemm.GetDeviceKernelArgSize(&argument), - hipMemcpyHostToDevice)); - gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); - - invoker.Run(argument, StreamConfig{nullptr, false, 1}); - - bool pass = true; - if(config.do_verification) - { - using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceGemmMultipleD; - - for(std::size_t i = 0; i < gemm_descs.size(); i++) - { - auto karg = ggemm_kargs[i]; - auto dev_res_tensor = - Tensor(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{})); - c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data()); - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], - b_tensors[i], - d_tensors[i], - c_host_tensors[i], - a_element_op, - b_element_op, - cde_element_op); - - ref_invoker.Run(ref_argument); - pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]); - } - - std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" 
<< std::endl; - } - - if(config.time_kernel) - { - float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << gemm.GetTypeString() << std::endl; - } - - return pass; -} - -std::vector argToIntArray(char* input) -{ - std::vector out; - std::istringstream in(input); - std::string item; - - while(std::getline(in, item, ',')) - { - out.push_back(std::stoi(item)); - } - return out; -} - -int main(int argc, char* argv[]) -{ - ProblemSize problem_size; - ExecutionConfig config; - - if(argc < 10) - { - std::vector Ms{64, 127, 255, 129, 260, 190, 77}; - problem_size.group_count = Ms.size(); - - for(int i = 0; i < problem_size.group_count; i++) - { - problem_size.Ms.push_back(Ms[i]); - problem_size.Ns.push_back(252); - problem_size.Ks.push_back(4608); - - problem_size.stride_As.push_back(problem_size.Ks[i]); - problem_size.stride_Bs.push_back(problem_size.Ks[i]); - problem_size.stride_Cs.push_back(problem_size.Ns[i]); - - problem_size.stride_Ds.push_back({}); - for(int j = 0; j < NumDs; ++j) - { - problem_size.stride_Ds[i].push_back(problem_size.Ns[i]); - } - } - - std::cout - << "Usage:\n" - << "arg1: verification (0=no, 1=yes)\n" - << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" - << "arg3: time kernel (0=n0, 1=yes)\n" - << "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " - "64,64 64,64 128,128)\n" - << "... setting default values." << std::endl; - } - else - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - - problem_size.Ms = argToIntArray(argv[4]); - problem_size.Ns = argToIntArray(argv[5]); - problem_size.Ks = argToIntArray(argv[6]); - - problem_size.stride_As = argToIntArray(argv[7]); - problem_size.stride_Bs = argToIntArray(argv[8]); - problem_size.stride_Cs = argToIntArray(argv[9]); - - for(int j = 0; j < NumDs; ++j) - { - problem_size.stride_Ds.push_back(problem_size.stride_Cs); - } - - problem_size.group_count = problem_size.Ms.size(); - } - - return !run_grouped_gemm(problem_size, config); -} +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp index e4da397c230..e942aad1c16 100644 --- a/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp @@ -58,11 +58,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3 // clang-format off -//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| -//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; // clang-format on diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp index d5b22058920..fb3a6f0b4fa 100644 --- a/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp @@ -57,11 +57,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3 // clang-format off -//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| 
Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| -//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; // clang-format on diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 764b533455b..ffd0c5e9b7b 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -323,8 +323,8 @@ bool run_grouped_gemm_example(int argc, char* argv[]) { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("arg4: async hargs (0=n0, 1=yes)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4: async hargs (0=no, 1=yes)\n"); printf("arg5: group count (default=16)\n"); #if defined(EXAMPLE_USE_SPLITK) printf("arg6: k-batch count (default=1)\n"); diff --git a/example/15_grouped_gemm/run_grouped_gemm_multiple_d_example.inc b/example/15_grouped_gemm/run_grouped_gemm_multiple_d_example.inc new file mode 100644 index 
00000000000..a71a23ab797 --- /dev/null +++ b/example/15_grouped_gemm/run_grouped_gemm_multiple_d_example.inc @@ -0,0 +1,341 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector> stride_Ds; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument; + using GemmDesc = ck::tensor_operation::device::GemmDesc; + + // GEMM shape + std::vector gemm_descs; + std::vector ggemm_kargs; + std::vector p_Cs; + std::vector p_As; + std::vector p_Bs; + std::vector> p_Ds = {}; + + gemm_descs.reserve(group_count); + ggemm_kargs.reserve(group_count); + p_As.reserve(group_count); + p_Bs.reserve(group_count); + p_Ds.reserve(group_count); + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector, NumDs>> d_tensors; + std::vector> c_host_tensors; + std::vector> c_device_result_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_result_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + std::vector> d_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + d_tensors_device.resize(group_count); // reserve and update vector size + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + + auto d0_tensor = Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); + auto d1_tensor = Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); + + std::array, NumDs> d_tens = {d0_tensor, d1_tensor}; + d_tensors.push_back(d_tens); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_result_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc + << " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() + + 
sizeof(BDataType) * b_tensors[i].GetElementSize() + + sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs + + sizeof(EDataType) * c_device_result_tensors[i].GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + for(int j = 0; j < NumDs; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + } + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + for(int j = 0; j < NumDs; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + for(int j = 0; j < NumDs; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential{}); + } + } + } + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType))); + b_tensors_device.emplace_back( + std::make_unique(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType))); + c_tensors_device.emplace_back(std::make_unique( + c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType))); + + for(int j = 0; j < NumDs; ++j) + { + d_tensors_device[i].emplace_back(std::make_unique( + d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType))); + } + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + for(int j = 0; j < NumDs; ++j) + { + d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data()); + } + c_tensors_device[i]->SetZero(); + + p_As.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_Ds.push_back( + {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}); + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + + // The device op does not have to know M problem size at lunch time. + gemm_descs.push_back({0, + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + problem_size.stride_Cs[i], + {problem_size.stride_Cs[i], problem_size.stride_Cs[i]}}); + ggemm_kargs.push_back( + {a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + {problem_size.stride_Cs[i], problem_size.stride_Cs[i]}, + problem_size.stride_Cs[i]}); + } + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), + ggemm_kargs.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + + invoker.Run(argument, StreamConfig{nullptr, false, 1}); + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmMultipleD; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + auto karg = ggemm_kargs[i]; + auto dev_res_tensor = + Tensor(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{})); + c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data()); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + d_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + cde_element_op); + + ref_invoker.Run(ref_argument); + pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]); + } + + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl; + } + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + return pass; +} + +std::vector argToIntArray(char* input) +{ + std::vector out; + std::istringstream in(input); + std::string item; + + while(std::getline(in, item, ',')) + { + out.push_back(std::stoi(item)); + } + return out; +} + +bool run_grouped_gemm_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + if(argc < 10) + { + std::vector Ms{64, 127, 255, 129, 260, 190, 77}; + problem_size.group_count = Ms.size(); + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(Ms[i]); + problem_size.Ns.push_back(252); + problem_size.Ks.push_back(4608); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + + problem_size.stride_Ds.push_back({}); + for(int j = 0; j < NumDs; ++j) + { + problem_size.stride_Ds[i].push_back(problem_size.Ns[i]); + } + } + + std::cout + << "Usage:\n" + << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " + "64,64 64,64 128,128)\n" + << "... setting default values." 
<< std::endl; + } + else + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.Ms = argToIntArray(argv[4]); + problem_size.Ns = argToIntArray(argv[5]); + problem_size.Ks = argToIntArray(argv[6]); + + problem_size.stride_As = argToIntArray(argv[7]); + problem_size.stride_Bs = argToIntArray(argv[8]); + problem_size.stride_Cs = argToIntArray(argv[9]); + + for(int j = 0; j < NumDs; ++j) + { + problem_size.stride_Ds.push_back(problem_size.stride_Cs); + } + + problem_size.group_count = problem_size.Ms.size(); + } + + return run_grouped_gemm(problem_size, config); +} diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp index 3b12e7feb00..4f884b1df3a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -151,7 +151,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1 PrefetchStages; } + static bool __host__ __device__ BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } static TailNumber BlockLoopTailNum(index_t num_loop) { @@ -707,7 +710,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1 PrefetchStages; } + __host__ __device__ static bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } static TailNumber BlockLoopTailNum(index_t num_loop) { diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp index ade80358776..2154f358150 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp @@ -3,6 +3,11 @@ #pragma once +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/stream_utility.hpp" + #include "device_grouped_gemm.hpp" namespace ck { @@ -43,6 +48,59 @@ struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm +struct TileLoopKernelConfig +{ + // The oversubscription factor for the number of blocks that can simultaneously reside on + // GPU. + static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1; + // static constexpr int BLOCK_WAVES = BlockSize / get_warp_size(); + static constexpr int CU_SIMDS = 4; + // Assume we want to have at most 2 waves per SIMD + // static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); + static int GetCuBlocks() + { + int BLOCK_WAVES = BlockSize / get_warp_size(); + return ck::math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); + } + + template + static int CalculateMaxOccupancyGridSize(const KernelFunction& kernel, + const StreamConfig& stream_config) + { + // Calculate max number of workgroups that can simultaneously reside on the CU. + int occ_num_blocks = GetKernelOccupancy(kernel); + int cu_count = getAvailableComputeUnitCount(stream_config); + + if(stream_config.log_level_ > 0) + { + std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks + << ", available CUs count: " << cu_count << ", occup. 
grid size: " + << ck::math::min(occ_num_blocks, GetCuBlocks()) * cu_count << std::endl; + } + + return cu_count * ck::math::min(occ_num_blocks, GetCuBlocks()); + } + + template + static int GetKernelOccupancy(const KernelFunction& kernel) + { + int occupancy = 0; + ck::hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + return occupancy; + } + + static int GetComputeUnitCount() + { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + ck::hip_check_error(hipGetDevice(&dev)); + ck::hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + } +}; + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp new file mode 100644 index 00000000000..b7c0d89e0f3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp @@ -0,0 +1,689 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/host_utility/stream_utility.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// +/// @brief Entry point kernel for device-wide Grouped GEMM operation. +/// +/// @param[in] gemm_descs_const The pointer to the array of GEMM descriptor structures. +/// @param[in] group_count The number of together processed GEMMs. +/// +/// @tparam GridwiseGemm The specific GridwiseGEMM algorithm implementation. +/// @tparam GemmDesc The structure holding all necessary descriptors and +/// other data needed for grouped gemm calculation and work +/// distribution. +/// @tparam LocalBlock2ETileMap The structure providing mapping between workgroup ids, +/// the data tiles to process and the output tiles. 
+/// +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_multiple_d_wmma(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ uint8_t p_shared[LDS_size]; + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + constexpr auto NumDTensor = DsDataType::Size(); + index_t tile_id = get_block_1d_id(); + index_t tile_offset = 0; + index_t group_id = -1; + index_t group_offset = 0; + index_t grid_size_grp = 0; + + index_t gemm_tile_id_start = 0; + index_t gemm_tile_id_end = 0; + + index_t M = 0, N = 0, K = 0; + + auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1); + + do + { + // Find corresponding GEMM group for our tile + while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end) && + group_id < group_count) + { + group_offset += grid_size_grp; + group_id++; + + if(group_id >= group_count) + return; + + M = gemm_desc_ptr[group_id].M; + N = gemm_desc_ptr[group_id].N; + K = gemm_desc_ptr[group_id].K; + + if(M == 0 || N == 0 || K == 0) + { + grid_size_grp = 0; + continue; + } + + b2c_tile_map = + OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset); + grid_size_grp = b2c_tile_map.CalculateGridSize(M, N); + + gemm_tile_id_start = group_offset; + gemm_tile_id_end = group_offset + grid_size_grp; + } + + // Create A&B grid pointer containing their single tensors + typename GridwiseGemm::AsGridPointer p_as_grid = Tuple( + static_cast(gemm_desc_ptr[group_id].p_a_grid)); + typename GridwiseGemm::BsGridPointer p_bs_grid = Tuple( + static_cast(gemm_desc_ptr[group_id].p_b_grid)); + + // Make a DsGridPointer instance containing all D tensors + using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); + DsGridPointer p_ds_grid; + std::array stride_Ds; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + p_ds_grid(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + stride_Ds[i] = gemm_desc_ptr[group_id].StrideDs[i]; + }); + + index_t K_split = ck::math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + // Update tile offset if we have moved within group + b2c_tile_map.UpdateTileOffset(tile_offset); + + using Problem = typename GridwiseGemm::Problem; + auto problem = Problem(gemm_desc_ptr[group_id].M, + gemm_desc_ptr[group_id].N, + gemm_desc_ptr[group_id].K, + std::array{gemm_desc_ptr[group_id].StrideA}, + std::array{gemm_desc_ptr[group_id].StrideB}, + stride_Ds, + gemm_desc_ptr[group_id].StrideE, + 1); + + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + constexpr TailNumber TailNum = TailNumber::Full; + + if(has_main_k_block_loop) + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + GridwiseGemm::template Run( + p_as_grid, + p_bs_grid, + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + b2c_tile_map, + a_element_op, + b_element_op, + cde_element_op, + 
epilogue_args); + } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + GridwiseGemm::template Run( + p_as_grid, + p_bs_grid, + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + b2c_tile_map, + a_element_op, + b_element_op, + cde_element_op, + epilogue_args); + } + } + + tile_id += get_grid_size(); + tile_offset += get_grid_size(); + + } while(group_id < group_count); +#else + ignore = gemm_descs_const; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; +#endif // end of if (defined(__gfx11__) || defined(__gfx12__)) +} + +template + +struct DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3 + : public DeviceGroupedGemmTileLoop +{ + using DeviceOp = DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, // PermuteA not supported by GridwiseOp. + false>; // PermuteB not supported by DeviceGroupedGemmTileLoop base class. 
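+
+    // Illustrative usage sketch (not a definitive API contract): the host-side names
+    // `gemm_descs`, `kargs_host` and `kargs_dev` are placeholders; the actual call
+    // sequence is the one exercised by
+    // example/15_grouped_gemm/run_grouped_gemm_multiple_d_example.inc.
+    //
+    //   auto gemm    = DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3<...>{};
+    //   auto invoker = gemm.MakeInvoker();
+    //   auto arg     = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Es, gemm_descs,
+    //                                    a_element_op, b_element_op, cde_element_op);
+    //   if(!gemm.IsSupportedArgument(arg))
+    //       throw std::runtime_error("unsupported GEMM problem");
+    //
+    //   // Kernel arguments live in a caller-owned device buffer.
+    //   DeviceMem kargs_dev(gemm.GetDeviceKernelArgSize(&arg));
+    //   hip_check_error(hipMemcpy(kargs_dev.GetDeviceBuffer(), kargs_host.data(),
+    //                             gemm.GetDeviceKernelArgSize(&arg),
+    //                             hipMemcpyHostToDevice));
+    //   gemm.SetDeviceKernelArgs(arg, kargs_dev.GetDeviceBuffer());
+    //
+    //   invoker.Run(arg, StreamConfig{nullptr, false, 1});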
+ + using KernelConfig = TileLoopKernelConfig; + using KernelArguments = GroupedGemmKernelArgument; + using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::vector& /* p_As */, + std::vector& /* p_Bs */, + std::vector>& /* p_Ds */, + std::vector& /* p_Es */, + const std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + int occupancy_num_blocks, + int gpu_cu_count) + : group_count_{static_cast(gemm_descs.size())}, + occupancy_num_blocks_{occupancy_num_blocks}, + gpu_cu_count_{gpu_cu_count}, + gemm_descs_{gemm_descs}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + tile_count_{0} + { + for(const auto& desc : gemm_descs) + { + const auto M = desc.M_; + const auto N = desc.N_; + const auto b2c_tile_map = Block2ETileMap(M, N); + tile_count_ += b2c_tile_map.CalculateGridSize(M, N); + } + } + + index_t group_count_; + const void* p_dev_gemm_args_; + int occupancy_num_blocks_; + int gpu_cu_count_; + const std::vector& gemm_descs_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + index_t tile_count_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + /// + /// @brief Launch Grouped Gemm kernel. + /// + /// @note This function overload is using user provided device buffer for kernel + /// arguments. + /// + /// @param[in] arg The structure containing kernel arguments (in host + /// memory). + /// @param[in] dev_gemm_args The pointer to device memory with kernel arguments. + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, + const void* dev_gemm_args, + const StreamConfig& stream_config = StreamConfig{}) + { + if(dev_gemm_args == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + const auto kernel = GetKernelFunction(); + + int grid_size = KernelConfig::CalculateMaxOccupancyGridSize(kernel, stream_config); + + if(stream_config.log_level_ > 0) + { + std::cout << "grid_size: " << grid_size << " tile_count: " << arg.tile_count_ + << std::endl; + } + + // run multiple kernels + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(dev_gemm_args), + arg.group_count_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_); + } + + /// + /// @brief Launch Grouped Gemm kernel. + /// + /// @note This function overload is using device buffers (for kernel arguments and + /// for kernel auxiliary workspace) provided with an argument. The user should + /// call @see GetDeviceKernelArgSize, and @see SetDeviceKernelArgs, on arg + /// parameter to properly allocate those buffers. + /// + /// @param[in] arg The structure containing kernel arguments (in host memory). + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) 
+ /// + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(arg.p_dev_gemm_args_ == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + return Run(arg, arg.p_dev_gemm_args_, stream_config); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static auto GetKernelFunction() + { + const auto kernel = kernel_grouped_gemm_multiple_d_wmma; + return kernel; + } + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + bool supported = true; + for(index_t i = 0; i < arg.group_count_; ++i) + { + std::array placeholder_p_ds_grid{}; + std::array stride_Ds; + std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin()); + + typename GridwiseGemm::Argument gridwise_arg( + std::array{nullptr}, // p_a_grid, + std::array{nullptr}, // p_b_grid, + placeholder_p_ds_grid, // p_ds_grid, + nullptr, // p_e_grid , + arg.gemm_descs_[i].M_, + arg.gemm_descs_[i].N_, + arg.gemm_descs_[i].K_, + std::array{arg.gemm_descs_[i].stride_A_}, + std::array{arg.gemm_descs_[i].stride_B_}, + stride_Ds, + arg.gemm_descs_[i].stride_C_, + 1, // KBatch + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + false); + + bool group_arg_valid = GridwiseGemm::CheckValidity(gridwise_arg); + supported = supported && group_arg_valid; + + if(!group_arg_valid) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[" << __func__ << "] group id: " << i + << " has invalid GridwiseGemm settings!" 
<< std::endl; + gridwise_arg.Print(); + } + } + } + + return supported; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static int GetKernelOccupancy() + { + const auto kernel = GetKernelFunction(); + return KernelConfig::GetKernelOccupancy(kernel); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) + { + int occupancy = GetKernelOccupancy(); + int num_cu = KernelConfig::GetComputeUnitCount(); + + return Argument{p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op, + occupancy, + num_cu}; + } + + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) override + { + int occupancy = GetKernelOccupancy(); + int num_cu = KernelConfig::GetComputeUnitCount(); + + return std::make_unique(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op, + occupancy, + num_cu); + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::ostringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3" + << "<" + << std::string(ALayout::name)[0] << "," + << std::string(BLayout::name)[0] << "," + << std::string(ELayout::name)[0] << "," + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] + << ">"; + // clang-format on + + return str.str(); + } + + void SetDeviceKernelArgs(Argument& arg, + void* p_dev_kernel_args, + const void* p_host_kernel_args) const + { + arg.p_dev_gemm_args_ = p_dev_kernel_args; + hip_check_error(hipMemcpyAsync(p_dev_kernel_args, + p_host_kernel_args, + GetDeviceKernelArgSize(&arg), + hipMemcpyHostToDevice)); + } + + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, + void* p_dev_kernel_args, + const void* p_host_kernel_args) const override + { + return SetDeviceKernelArgs( + *dynamic_cast(p_arg), p_dev_kernel_args, p_host_kernel_args); + } + + void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const + { + arg.p_dev_gemm_args_ = p_dev_kernel_args; + } + + virtual 
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), p_dev_kernel_args); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + return dynamic_cast(p_arg)->group_count_ * sizeof(KernelArguments); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 4492e6474fc..a9e81f55631 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -4,6 +4,7 @@ #pragma once #include +#include #include #include @@ -26,6 +27,18 @@ namespace ck { namespace tensor_operation { namespace device { +// Dummy kernel to use as a fallback in the kernel selection logic +// Is not used in practice, but only used in case of misconfigured parameters +template +__global__ void kernel_dummy(const void CK_CONSTANT_ADDRESS_SPACE*, + const index_t, + const AElementwiseOperation, + const BElementwiseOperation, + const CDEElementwiseOperation) +{ +} /// /// @brief Entry point kernel for device-wide Grouped GEMM operation. /// @@ -528,6 +541,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; + using KernelConfig = TileLoopKernelConfig; using KernelArguments = GroupedGemmKernelArgument; using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2; @@ -574,22 +588,6 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop index_t tile_count_; }; - struct KernelConfig - { - // The oversubscription factor for the number of blocks that can simultaneously reside on - // GPU. - static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1; - // static constexpr int BLOCK_WAVES = BlockSize / get_warp_size(); - static constexpr int CU_SIMDS = 4; - // Assume we want to have at most 2 waves per SIMD - // static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); - static int GetCuBlocks() - { - int BLOCK_WAVES = BlockSize / get_warp_size(); - return math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); - } - }; - // Invoker struct Invoker : public BaseInvoker { @@ -666,58 +664,17 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop const void* dev_gemm_args, const StreamConfig& stream_config) const { - const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + const auto kernel = GetKernelFunction(); return LaunchKernel(kernel, arg, dev_gemm_args, stream_config); } - template - int CalculateMaxOccupancyGridSize(const KernelFunction& kernel, - const StreamConfig& stream_config) const - { - // Calculate max number of workgroups that can simultaneously reside on the CU. - int occ_num_blocks = 0; - size_t dyn_shared_mem_per_blk = 0; - hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( - &occ_num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk)); - - int cu_count = getAvailableComputeUnitCount(stream_config); - - if(stream_config.log_level_ > 0) - { - std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks - << ", available CUs count: " << cu_count << ", occup. 
grid size: " - << ck::math::min(occ_num_blocks, KernelConfig::GetCuBlocks()) * cu_count - << std::endl; - } - - return cu_count * ck::math::min(occ_num_blocks, KernelConfig::GetCuBlocks()); - } - template float LaunchKernel(const KernelFunction& kernel, const Argument& arg, const void* dev_gemm_args, const StreamConfig& stream_config) const { - int grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config); + int grid_size = KernelConfig::CalculateMaxOccupancyGridSize(kernel, stream_config); if(stream_config.log_level_ > 0) { @@ -835,65 +792,60 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop return IsSupportedArgument(*dynamic_cast(p_arg)); } - static int GetKernelOccupancy() + template + static auto GetKernelFunction() + { + const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + return kernel; + } + + static auto GetKernelFunction() { - int occupancy = 0; if(get_warp_size() == 64) { if constexpr(NXdlPerWave64 > 0) { - const auto kernel = kernel_grouped_gemm_multiple_d_xdl; - hip_check_error( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + const auto kernel = GetKernelFunction(); + return kernel; } } else { - if constexpr(NXdlPerWave32 > 0) { - const auto kernel = kernel_grouped_gemm_multiple_d_xdl; - hip_check_error( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + const auto kernel = GetKernelFunction(); + return kernel; } } - return occupancy; + + // This is here to handle the case where MXdlPerWave/NxdPerWave is too small + // This is caught by IsSupportedArgument(), but as GetKernelFunction is sometimes called + // before we need a fallback kernel to return here. + return kernel_dummy; + } + + static int GetKernelOccupancy() + { + const auto kernel = GetKernelFunction(); + return KernelConfig::GetKernelOccupancy(kernel); } static auto MakeArgument(std::vector& p_As, @@ -906,13 +858,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop CDEElementwiseOperation cde_elementwise_op) { int occupancy = GetKernelOccupancy(); - int num_cu; - - hipDeviceProp_t dev_prop; - hipDevice_t dev; - hip_check_error(hipGetDevice(&dev)); - hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); - num_cu = dev_prop.multiProcessorCount; + int num_cu = KernelConfig::GetComputeUnitCount(); return Argument{p_As, p_Bs, @@ -937,13 +883,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop CDEElementwiseOperation cde_elementwise_op) override { int occupancy = GetKernelOccupancy(); - int num_cu; - - hipDeviceProp_t dev_prop; - hipDevice_t dev; - hip_check_error(hipGetDevice(&dev)); - hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); - num_cu = dev_prop.multiProcessorCount; + int num_cu = KernelConfig::GetComputeUnitCount(); return std::make_unique(p_As, p_Bs, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp index 6914def1101..714d5670204 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp @@ -126,7 +126,6 @@ template + typename ComputeTypeB = ComputeTypeA> struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK; // PermuteB not supported by DeviceBatchedGemm base class. 
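

The hunks above fold the previously inlined occupancy and compute-unit queries into a shared KernelConfig helper (TileLoopKernelConfig) and route kernel selection through GetKernelFunction, with kernel_dummy as a fallback for misconfigured NXdlPerWave cases. For readers who want the pattern in isolation, the following is a minimal standalone sketch of what those queries do, reconstructed from the inline code being removed in this patch; the namespace and the max_blocks_per_cu parameter are illustrative assumptions, not the actual TileLoopKernelConfig API.

#include <hip/hip_runtime.h>
#include <algorithm>

namespace tile_loop_config_sketch {

// Max workgroups of block_size threads that can be simultaneously resident
// per compute unit for the given kernel (no dynamic LDS assumed).
template <typename KernelFunction>
int GetKernelOccupancy(KernelFunction kernel, int block_size)
{
    int occupancy = 0;
    (void)hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, block_size, 0);
    return occupancy;
}

// Number of compute units on the currently selected device.
inline int GetComputeUnitCount()
{
    hipDevice_t dev       = 0;
    hipDeviceProp_t props = {};
    (void)hipGetDevice(&dev);
    (void)hipGetDeviceProperties(&props, dev);
    return props.multiProcessorCount;
}

// Persistent-kernel grid size: fill every CU with as many blocks as the
// kernel's occupancy allows, capped by a caller-chosen per-CU limit.
template <typename KernelFunction>
int CalculateMaxOccupancyGridSize(KernelFunction kernel, int block_size, int max_blocks_per_cu)
{
    return GetComputeUnitCount() *
           std::min(GetKernelOccupancy(kernel, block_size), max_blocks_per_cu);
}

} // namespace tile_loop_config_sketch
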
+ false, // PermuteA not supported by GridwiseOp + false>; // PermuteB not supported by DeviceGroupedGemm base class using CGridDesc_M_N = remove_cvref_t( @@ -779,7 +776,7 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK(c) * d0 + d1; - e = y; + const half_t y = + type_convert(c * type_convert(d0) + type_convert(d1)); + e = y; } template <> __host__ __device__ void operator()(bhalf_t& e, @@ -245,8 +247,9 @@ struct MultiplyAdd const bhalf_t& d0, const bhalf_t& d1) const { - const bhalf_t y = type_convert(c) * d0 + d1; - e = y; + const bhalf_t y = + type_convert(c * type_convert(d0) + type_convert(d1)); + e = y; } template <> __host__ __device__ void operator()(float& e, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index c3c14edfb8d..9f7fd47083e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -334,14 +334,14 @@ struct GridwiseGemm_wmma_cshuffle_v3 struct Problem { __host__ Problem() = default; - __host__ Problem(index_t M_, - index_t N_, - index_t K_, - std::array StrideAs_, - std::array StrideBs_, - std::array StrideDs_, - index_t StrideE_, - index_t KBatch_) + __host__ __device__ Problem(index_t M_, + index_t N_, + index_t K_, + std::array StrideAs_, + std::array StrideBs_, + std::array StrideDs_, + index_t StrideE_, + index_t KBatch_) : M{M_}, N{N_}, K{K_}, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 11e9a6dbf73..79549d63853 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -351,64 +351,65 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // Calculate grid size taking into account splitk (KBatch) // 2D grid (x,z) - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + __host__ __device__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); } // Calculate grid size taking into account splitk (KBatch) and multiple groups (Batch) // 3D grid (x,y,z) - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch) + __host__ __device__ static auto + CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch); } - __host__ static auto CalculateMPadded(index_t M) + __host__ __device__ static auto CalculateMPadded(index_t M) { return math::integer_least_multiple(M, MPerBlock); } - __host__ static auto CalculateNPadded(index_t N) + __host__ __device__ static auto CalculateNPadded(index_t N) { return math::integer_least_multiple(N, NPerBlock); } - __host__ static auto CalculateKPadded(index_t K) + __host__ __device__ static auto CalculateKPadded(index_t K) { return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; } - __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) { auto K_t = K_Batch * KPerBlock; return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); } - __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + __host__ 
__device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) { auto K_t = K_Batch * KPerBlock; return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); } - __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) { auto K_t = K_Batch * KPerBlock; return (K + K_t - 1) / K_t * KPerBlock; } - __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) { constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); auto K_t = K_Batch * KReadVec; return (K + K_t - 1) / K_t * KReadVec; } - __host__ static auto CalculateMBlock(index_t M) + __host__ __device__ static auto CalculateMBlock(index_t M) { return math::integer_divide_ceil(M, MPerBlock); } - __host__ static auto CalculateNBlock(index_t N) + __host__ __device__ static auto CalculateNBlock(index_t N) { return math::integer_divide_ceil(N, NPerBlock); } @@ -963,14 +964,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return true; } - __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / KPerBlock; return BlockwiseGemmPipe::BlockHasHotloop(num_loop); } - __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) { const index_t num_loop = K / KPerBlock; diff --git a/include/ck/utility/tuple.hpp b/include/ck/utility/tuple.hpp index 78931407d82..16575950307 100644 --- a/include/ck/utility/tuple.hpp +++ b/include/ck/utility/tuple.hpp @@ -7,6 +7,7 @@ #include "ck/utility/sequence.hpp" #include "ck/utility/type.hpp" #include "ck/utility/enable_if.hpp" +#include namespace ck { @@ -220,4 +221,49 @@ constexpr Tuple tie(Args&... args) noexcept return {args...}; } +// +// tuple_map: Map tuple with a different type +// e.g. tuple_map> becomes Tuple, Wrapper, Wrapper> +// +template
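

The tuple.hpp hunk introduces a tuple_map utility whose comment describes mapping every element type of a Tuple through a wrapper template (Tuple of X, Y, Z becomes Tuple of Wrapper of X, Wrapper of Y, Wrapper of Z). As context for the ck::Tuple implementation that follows, here is a minimal self-contained sketch of the same type-level map written against std::tuple; the names TupleMap and Wrapper are assumptions for illustration only, not CK's API.

#include <tuple>
#include <type_traits>

// Apply the template F to every element type of a std::tuple:
// std::tuple<Xs...> -> std::tuple<F<Xs>...>
template <template <typename> class F, typename T>
struct TupleMap;

template <template <typename> class F, typename... Xs>
struct TupleMap<F, std::tuple<Xs...>>
{
    using type = std::tuple<F<Xs>...>;
};

template <template <typename> class F, typename T>
using TupleMap_t = typename TupleMap<F, T>::type;

// Example wrapper, mirroring the "Wrapper" used in the tuple_map comment above.
template <typename T>
struct Wrapper
{
    T value;
};

static_assert(std::is_same_v<TupleMap_t<Wrapper, std::tuple<int, float, double>>,
                             std::tuple<Wrapper<int>, Wrapper<float>, Wrapper<double>>>,
              "every element type is rewrapped");
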