From 8e763af431772f963dca71e8d037a71e0ceab8c2 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Mon, 1 Dec 2025 09:25:43 +0000 Subject: [PATCH 01/14] Initial implementation removing cshuffle (without transpose) --- example/01_gemm/gemm_wmma_fp16_v3.cpp | 12 +- .../grid/epilogue_cshuffle_v3_wmma_base.hpp | 2 + .../gpu/grid/epilogue_direct_store.hpp | 145 +++++++ ...wise_ab_transfer_wave_tiles_interleave.hpp | 383 ++++++++++++++++++ .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 100 +++-- 5 files changed, 607 insertions(+), 35 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp diff --git a/example/01_gemm/gemm_wmma_fp16_v3.cpp b/example/01_gemm/gemm_wmma_fp16_v3.cpp index 5b10edd681a..2ad704cda8b 100644 --- a/example/01_gemm/gemm_wmma_fp16_v3.cpp +++ b/example/01_gemm/gemm_wmma_fp16_v3.cpp @@ -11,8 +11,8 @@ using AccDataType = float; using CShuffleDataType = ck::half_t; using CDataType = ck::half_t; -using ALayout = Col; -using BLayout = Row; +using ALayout = Row; +using BLayout = Col; using CLayout = Row; using AElementOp = PassThrough; @@ -31,10 +31,10 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf 8, 8, 16, 16, 2, 8, - S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, - 1, 1, 8, 1, - S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, - 1, 1, 8, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>; diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp index b8dd5905aa2..dd12cdca8c2 100644 --- a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp @@ -59,6 +59,8 @@ struct EpilogueCShuffleBase 1, CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma>>; + __device__ static constexpr bool IsLDSNeeded() { return true; } + // *Caution Here repeat is shuffle repeat __device__ static constexpr auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp new file mode 100644 index 00000000000..859225a831a --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp @@ -0,0 +1,145 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +struct EpilogueDirectStore +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + __device__ static constexpr bool IsLDSNeeded() { return false; } + + template + __device__ static void Run(CThreadBuf& c_thread_buf, + DsGridPointer, + EDataType* p_e_grid, + void*, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + CDEElementwiseOperation& cde_element_op, + const index_t& block_m_id, + const index_t& block_n_id) + { + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // C mapping in single thread. + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + BlockwiseGemmPipe:: + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // C mapping in single block + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + BlockwiseGemmPipe:: + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I1); + constexpr auto MSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I2); + constexpr auto NWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I4); + constexpr auto NThreadPerSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I5); + constexpr auto MAccVgprs = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I6); + + // origin + const auto c_thread_mtx_on_block = + BlockwiseGemmPipe::CalculateCThreadOriginDataIndex(I0, I0); + + const auto m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(c_thread_mtx_on_block[I0])); + + const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(c_thread_mtx_on_block[I1])); + + // E grid descriptor + const auto c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + transform_tensor_descriptor( + e_grid_desc_mblock_mperblock_nblock_nperblock, + make_tuple(make_freeze_transform(block_m_id), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + 
Number{}, + Number{})), + make_freeze_transform(block_n_id), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<4, 5, 3>{})); + + auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + EDataType, + decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + decltype(c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + CDEElementwiseOperation, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 3, + NRepeat, // VectorSize + EGlobalMemoryDataOperation, + 1, + false>{c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(m_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + n_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3]), + cde_element_op}; + + c_thread_copy.Run( + c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_tuple(I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + e_grid_buf); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp new file mode 100644 index 00000000000..9404271dad2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp @@ -0,0 +1,383 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/utility/amd_address_space.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +template +struct ABTransferWaveTilesInterleave +{ + __device__ static constexpr bool IsLDSNeeded() { return true; } + + static_assert(!(is_same_v, pk_i4_t>), + "wave tile transfer method does not support pk_i4_t"); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static constexpr index_t MNKRow = 2; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr bool ABDoTranspose = !is_same_v; + static_assert(!ABDoTranspose, "wave tile interleaved transfer does not support transpose yet"); + + // Tiles distribution for global memory loading + // Notes: support for not power of 2 needs to be reviewed later on + // The tiles are distributed along the non-contiguous matrix dimension + // Example 4 waves A row-major MPerBlock = 64, KPerBlock = 64 + // MRepeat = 1, KRepeat = 4 + // ------------- + // |W0| | | | + // ------------- + // |W1| | | | + // ------------- + // |W2| | | | + // ------------- + // |W3| | | | + // ------------- + // Example 4 waves A column-major MPerBlock = 64, KPerBlock = 64 + // MRepeat = 4, KRepeat = 1 + // ------------- + // |W0|W1|W2|W3| + // ------------- + // | | | | | + // ------------- + // | | | | | + // ------------- + // | | | | | + // ------------- + static constexpr index_t NumberOfWaves = BlockSize / WaveSize; + static constexpr index_t MNMajorWaves_ = + MNPerBlock / MNPerWmma % std::min(MNPerBlock / MNPerWmma, NumberOfWaves) == 0 + ? std::min(MNPerBlock / MNPerWmma, NumberOfWaves) + : (MNPerBlock / MNPerWmma % 2 == 0 ? 2 : 1); + static constexpr index_t KMajorWaves_ = + KPerBlock / KPack % std::min(KPerBlock / KPack, NumberOfWaves) == 0 + ? std::min(KPerBlock / KPack, NumberOfWaves) + : (KPerBlock / KPack % 2 == 0 ? 2 : 1); + + static constexpr index_t MNWaves_Load = + ABDoTranspose ? NumberOfWaves / KMajorWaves_ : MNMajorWaves_; + static constexpr index_t KWaves_Load = + ABDoTranspose ? 
KMajorWaves_ : NumberOfWaves / MNMajorWaves_; + static constexpr index_t KRepeat_Load = KPerBlock / (KWaves_Load * KPack); + static constexpr index_t MNRepeat_Load = MNPerBlock / (MNWaves_Load * MNPerWmma); + + static constexpr index_t MNWaves_ = MNWaves_User; + static constexpr index_t KWaves_ = (BlockSize / WaveSize) / MNWaves_User; + static constexpr index_t KRepeat_ = KPerBlock / (KWaves_ * KPack); + static constexpr index_t MNRepeat_ = MNPerBlock / (MNWaves_ * MNPerWmma); + + template + __host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc, + index_t sizeMN, + index_t, + index_t sizeK, + index_t, + index_t, + index_t) + { + // Notes: padding is currently not supported + static_assert(!PadMN && !PadK, "padding is currently not supported"); + + // Divide the base descriptor MN_K into tiles + const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor( + base_desc, + make_tuple(make_unmerge_transform(make_tuple( + math::integer_divide_ceil(sizeMN, Number{}), + Number{})), + make_unmerge_transform(make_tuple( + math::integer_divide_ceil(sizeK, Number{}), Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); + + // The distinction is needed to get the same global indices for both layouts + // Divide each tile in 2 16x8 subtile + // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize + // MNKRow = 0-1 + // LaneLocal = 0-15 + // VectorSize must be 8 + if constexpr(!ABDoTranspose) + { + const auto ab_grid_desc_mntiles_ktiles_mnrepeat = transform_tensor_descriptor( + ab_grid_desc_mntiles_ktiles, + make_tuple( + make_pass_through_transform( + math::integer_divide_ceil(sizeMN, Number{})), + make_pass_through_transform(math::integer_divide_ceil(sizeK, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<3, 2>{}, Sequence<4>{})); + + const auto ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1 = + transform_tensor_descriptor( + ab_grid_desc_mntiles_ktiles_mnrepeat, + make_tuple(make_pass_through_transform(math::integer_divide_ceil( + sizeMN, Number{})), + make_pass_through_transform( + math::integer_divide_ceil(sizeK, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}))), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4, 5>{})); + + // Freeze VectorSize to first element of the loading chunk (for convenience) + // Swap MNPerWmma and MNKRow for consistency with transpose descriptor + return transform_tensor_descriptor( + ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1, + make_tuple( + make_pass_through_transform( + math::integer_divide_ceil(sizeMN, Number{})), + make_pass_through_transform(math::integer_divide_ceil(sizeK, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_freeze_transform(I0)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<3>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<>{})); + } + else + { + // TODO + } + } + + __device__ static constexpr auto 
GetBlockDescriptor() + { + // LDS memory layouts: + // lanes within tiles stored contiguously in chunks of 8 elements + // tiles are then stored first in K dimension + // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize + const auto a_grid_desc_mraw_kraw = [&]() { + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + I1)); + }(); + + // Freeze VectorSize to first element of the chunk (for convenience) + return transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_freeze_transform(I0)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<>{})); + } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MNWaves_Load, KWaves_Load, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto GetBlockLaneIdx() + { + const index_t lane_id = __lane_id(); + + constexpr index_t LanesPerSubTile = ABDoTranspose ? KPack : MNPerWmma; + + constexpr auto laneid_to_block_lane_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MNKRow, LanesPerSubTile))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return laneid_to_block_lane_idx_adaptor.CalculateBottomIndex(make_multi_index(lane_id)); + } + + template + __device__ static auto GetGridLaneIdx() + { + const index_t lane_id = __lane_id(); + + constexpr index_t SubTilesRow = MNKRow; + constexpr index_t SubTilesCol = 4 / sizeof(ABDataType); + constexpr index_t LanesPerSubTile = + ABDoTranspose ? KPack / SubTilesCol : MNPerWmma / SubTilesCol; + constexpr auto dims_tuple = ABDoTranspose + ? make_tuple(SubTilesCol, SubTilesRow, LanesPerSubTile) + : make_tuple(SubTilesRow, SubTilesCol, LanesPerSubTile); + + constexpr auto laneid_to_grid_lane_idx_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(dims_tuple)), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto indices = + laneid_to_grid_lane_idx_adaptor.CalculateBottomIndex(make_multi_index(lane_id)); + + if constexpr(!ABDoTranspose) + { + return make_multi_index(indices[I0], indices[I1] * LanesPerSubTile + indices[I2]); + } + else + { + return make_multi_index(indices[I1], indices[I0] * LanesPerSubTile + indices[I2]); + } + } + + template + __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor, + BlockDescriptor& block_descriptor, + ABElementwiseOperation& ab_element_op, + const index_t block_mn_id, + const index_t) + { + // Note: GlobalBufferNum is currently not used but it will be needed + // once we add other pipelines. 
It is currently needed only for + // consistency with the thread tiles approach + static_assert(GlobalBufferNum == 1, "single global buffer is only supported"); + constexpr index_t NumABTensor = ABsDataType::Size(); + static_assert(NumABTensor == 1, "multiAB currently not supported"); + + using ABDataType = remove_cvref_t>; + + const auto wave_idx = GetWaveIdx(); + index_t wave_idK = wave_idx[I1]; + index_t wave_idMN = wave_idx[I0]; + + const auto grid_lane_id = GetGridLaneIdx(); + index_t lane_group_grid = grid_lane_id[I0]; + index_t lane_local_id_grid = grid_lane_id[I1]; + + const auto block_lane_id = GetBlockLaneIdx(); + index_t lane_group_block = block_lane_id[I0]; + index_t lane_local_id_block = block_lane_id[I1]; + + constexpr index_t MNRepeatRatio = MNRepeat_ / MNRepeat_Load; + return ThreadGroupTransferGlobal, + Sequence, + Sequence, + ABK1Value, + ABDoTranspose>( + grid_descriptor[I0], + block_descriptor, + make_multi_index(block_mn_id * MNWaves_ + wave_idMN / MNRepeatRatio, + wave_idK * KRepeat_, + (wave_idMN % MNRepeatRatio) * MNRepeat_Load, + lane_group_grid, + lane_local_id_grid), + make_multi_index(wave_idMN / MNRepeatRatio, + wave_idK * KRepeat_Load, + (wave_idMN % MNRepeatRatio) * MNRepeat_Load, + lane_group_block, + lane_local_id_block), + ab_element_op); + } + + template + __host__ __device__ static constexpr auto MakeWmmaTileDescriptor() + { + // This is a block descriptor used to read LDS memory into register + // It's defined in a way consistent with the existing implementation to + // avoid changes in the pipelines + return make_naive_tensor_descriptor(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}), + make_tuple(I0, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + I1)); + } + + __device__ static constexpr auto GetBlockStep() + { + // Grid descriptor step (MoveSrcSliceWindow) + return make_multi_index(I0, KWaves_Load * KRepeat_Load, I0, I0, I0); + } + + template + __device__ static constexpr index_t GetKDimension(const GridDescriptor& grid_desc) + { + return grid_desc.GetLength(I1) * KPack; + } + + template + __device__ static auto GetBuffer(LDSType* p_shared_AB, const IndexType& size) + { + return make_dynamic_buffer(p_shared_AB, size); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 79549d63853..036b997a2f8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -15,6 +15,7 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" @@ -24,6 +25,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp" #include 
"ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp" #include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp" #include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp" @@ -50,13 +52,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) std::is_same_v))) { #endif - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + using EpilogueType = + typename std::conditional::type; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; GridwiseGemm::template Run( p_shared, splitk_batch_offset, karg, epilogue_args); @@ -182,6 +190,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base static constexpr index_t NumATensor = AsDataType::Size(); static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); using LDSTypeA = typename std::conditional<(NumATensor > 1), @@ -248,9 +257,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_base !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 && GemmSpec == tensor_operation::device::GemmSpecialization::Default && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; + + static constexpr bool UseDirectStore = + NRepeat == 8 && sizeof(ComputeTypeB) == 2 && sizeof(EDataType) == 2 && NumDTensor == 0; #else static constexpr bool IsAWaveTransferApplicable = false; static constexpr bool IsBWaveTransferApplicable = false; + static constexpr bool UseDirectStore = false; #endif static constexpr index_t WaveSize = @@ -293,7 +306,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base static constexpr bool UseBlockPaddingB = BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4; - using BTransfer = typename std::conditional< IsBPreShuffled, ABTransferThreadTilesPreShuffle, typename std::conditional< IsBWaveTransferApplicable, - ABTransferWaveTiles, + typename std::conditional< + UseDirectStore, + ABTransferWaveTilesInterleave, + ABTransferWaveTiles>::type, ABTransferThreadTiles; + using EpilogueDirectStore = EpilogueDirectStore; + using EpilogueWelfordCShuffle = EpilogueWelfordCShuffle< DsDataType, EDataType, @@ -1000,18 +1031,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base max_lds_align) : 0; - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - EpilogueType:: - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat - .GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize + - b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize), - c_block_size * sizeof(CShuffleDataType)); + if constexpr(EpilogueType::IsLDSNeeded()) + { + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + EpilogueType:: + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize + + 
b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize), + c_block_size * sizeof(CShuffleDataType)); + } + else + { + return a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize + + b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize; + } } template @@ -1148,7 +1187,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base num_k_block_main_loop, num_k_block_per_scale); - // shuffle C and write out + // Epilogue: + // - CShuffle / direct store + // - Multiple Ds + // - Fused operations epilogue_args.template Run( c_thread_buf, p_ds_grid, From a99008064b409b4d21adf676d90971dd82c9e723 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Mon, 1 Dec 2025 17:17:47 +0000 Subject: [PATCH 02/14] Integrate cshuffle removal in fwd convoltion --- ...conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp | 79 +++++++++---------- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 40 +++++++++- 2 files changed, 72 insertions(+), 47 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp index df128c10b9c..edba3749999 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp @@ -63,8 +63,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif kernel_grouped_conv_fwd_wmma_cshuffle_v3( typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const AGridDesc_AK0_M_AK1 a_grid_desc_m_k, + const BGridDesc_BK0_N_BK1 b_grid_desc_n_k, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock @@ -82,13 +82,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) std::is_same_v))) { #endif - __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>()]; + using EpilogueType = + typename std::conditional::type; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; - GridwiseGemm::template Run{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + return in_gemmm_gemmk_desc; } template @@ -337,16 +341,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 const auto wei_gemmn_gemmk_desc = matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); - const auto N = wei_gemmn_gemmk_desc.GetLength(I0); - const auto K = wei_gemmn_gemmk_desc.GetLength(I1); - - const auto BK0 = K / BK1; - - return transform_tensor_descriptor(wei_gemmn_gemmk_desc, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + return wei_gemmn_gemmk_desc; } template @@ -452,10 +447,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 BlkGemmPipelineVer, AComputeDataType, BComputeDataType, - false, // PermuteA - false, // PermuteB - false, // IsBPreShuffled - true>; // ForceThreadTileTransfer + false, // PermuteA + false, // PermuteB + false, // IsBPreShuffled + false>; // ForceThreadTileTransfer // TODO: Previously available template param 
DoElementwiseBeforeCShuffle! @@ -798,8 +793,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 } { - const index_t GemmM = a_grid_desc_ak0_m_ak1_.GetLength(I1); - const index_t GemmN = b_grid_desc_bk0_n_bk1_.GetLength(I1); + const index_t GemmM = a_grid_desc_ak0_m_ak1_.GetLength(I0); + const index_t GemmN = b_grid_desc_bk0_n_bk1_.GetLength(I0); const auto MBlock = CTranspose ? GridwiseGemmCTranspose::CalculateMBlock(GemmN) : GridwiseGemmCTranspose::CalculateMBlock(GemmM); const auto NBlock = CTranspose ? GridwiseGemmCTranspose::CalculateNBlock(GemmM) @@ -1048,10 +1043,9 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 constexpr index_t minimum_occupancy = BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; - const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); - const index_t GemmK = - arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I0); + const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t num_workgroups_per_Conv_N = arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; @@ -1985,10 +1979,9 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 } // check Gridwise GEMM - const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); - const index_t GemmK = - arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I0); + const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); if constexpr(CTranspose) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 036b997a2f8..8dc1c26d086 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -250,16 +250,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_base #ifdef __gfx12__ static constexpr bool IsAWaveTransferApplicable = !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 && - GemmSpec == tensor_operation::device::GemmSpecialization::Default && + ((GemmSpec == tensor_operation::device::GemmSpecialization::Default && + !is_same_v) || + is_same_v) && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled; static constexpr bool IsBWaveTransferApplicable = !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 && - GemmSpec == tensor_operation::device::GemmSpecialization::Default && + ((GemmSpec == tensor_operation::device::GemmSpecialization::Default && + !is_same_v) || + is_same_v) && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; - static constexpr bool UseDirectStore = - NRepeat == 8 && sizeof(ComputeTypeB) == 2 && sizeof(EDataType) == 2 && NumDTensor == 0; + // We need to investigate if it makes sense to remove cshuffle for smaller types + static constexpr bool UseDirectStore = is_same_v && + sizeof(ComputeTypeB) == 2 && sizeof(EDataType) == 2 && + NumDTensor == 0; #else static constexpr bool IsAWaveTransferApplicable = false; static constexpr bool IsBWaveTransferApplicable = false; @@ -515,6 +521,19 @@ 
struct GridwiseGemm_wmma_cshuffle_v3_base Number{}); } + template + __device__ static auto MakeAGridDescriptor_AK0_M_AK1(const GridDescBase& base_desc) + { + const auto M = base_desc.GetLength(I0); + const auto K = base_desc.GetLength(I1); + + const auto AK0 = K / AK1Value; + + constexpr bool padM = false; + constexpr bool padK = false; + return ATransfer::template MakeGridDescriptor(base_desc, M, 0, K, 0, 0, AK0); + } + __host__ __device__ static auto MakeBsGridDescriptor_BK0_N_BK1(const index_t K, const index_t KPad, @@ -541,6 +560,19 @@ struct GridwiseGemm_wmma_cshuffle_v3_base Number{}); } + template + __device__ static auto MakeBGridDescriptor_BK0_N_BK1(const GridDescBase& base_desc) + { + const auto N = base_desc.GetLength(I0); + const auto K = base_desc.GetLength(I1); + + const auto BK0 = K / BK1Value; + + constexpr bool padN = false; + constexpr bool padK = false; + return BTransfer::template MakeGridDescriptor(base_desc, N, 0, K, 0, 0, BK0); + } + __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor() { constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); From 87e313d9b1bc25fe268ba2835cd9dfe0f701bb5f Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Tue, 16 Dec 2025 09:04:33 +0000 Subject: [PATCH 03/14] Refactor interleaved wave transfer and add padding support for wave transfer --- .../grid/gridwise_ab_transfer_wave_tiles.hpp | 80 +++++- ...wise_ab_transfer_wave_tiles_interleave.hpp | 269 ++++++------------ .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 2 +- 3 files changed, 146 insertions(+), 205 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp index cf471578ca0..867244c2518 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp @@ -77,26 +77,76 @@ struct ABTransferWaveTiles static constexpr index_t KRepeat_ = KPerBlock / (KWaves_ * KPack); static constexpr index_t MNRepeat_ = MNPerBlock / (MNWaves_ * MNPerWmma); + template + __host__ __device__ static auto PadGridDescriptor(GridDescriptorBase& base_desc, + index_t sizeMN, + index_t MNPad, + index_t sizeK, + index_t KPad, + index_t, + index_t) + { + if constexpr(PadMN && PadK) + { + // pad both MN and K + return transform_tensor_descriptor( + base_desc, + make_tuple(make_right_pad_transform(sizeMN, MNPad - sizeMN), + make_right_pad_transform(sizeK, KPad - sizeK)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(PadMN && !PadK) + { + // pad MN, but not K + return transform_tensor_descriptor( + base_desc, + make_tuple(make_right_pad_transform(sizeMN, MNPad - sizeMN), + make_pass_through_transform(sizeK)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(!PadMN && PadK) + { + // pad K, but not MN + return transform_tensor_descriptor( + base_desc, + make_tuple(make_pass_through_transform(sizeMN), + make_right_pad_transform(sizeK, KPad - sizeK)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad MN or K + return base_desc; + } + } + template __host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc, index_t sizeMN, - index_t, + index_t MNPad, index_t sizeK, - index_t, + index_t KPad, index_t, index_t) { // Notes: padding is currently not supported - 
static_assert(!PadMN && !PadK, "padding is currently not supported"); + static_assert(!((PadMN || PadK) && ABDoTranspose), + "padding is currently not supported with transpose"); + + const auto base_desc_padded = + PadGridDescriptor(base_desc, sizeMN, MNPad, sizeK, KPad, 0, 0); // Divide the base descriptor MN_K into tiles const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor( - base_desc, + base_desc_padded, make_tuple( make_unmerge_transform(make_tuple( - math::integer_divide_ceil(sizeMN, Number{}), Number{})), - make_unmerge_transform(make_tuple(math::integer_divide_ceil(sizeK, Number{}), - Number{}))), + math::integer_divide_ceil(MNPad, Number{}), Number{})), + make_unmerge_transform( + make_tuple(math::integer_divide_ceil(KPad, Number{}), Number{}))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); @@ -112,9 +162,9 @@ struct ABTransferWaveTiles transform_tensor_descriptor( ab_grid_desc_mntiles_ktiles, make_tuple(make_pass_through_transform( - math::integer_divide_ceil(sizeMN, Number{})), + math::integer_divide_ceil(MNPad, Number{})), make_pass_through_transform( - math::integer_divide_ceil(sizeK, Number{})), + math::integer_divide_ceil(KPad, Number{})), make_pass_through_transform(Number{}), make_unmerge_transform( make_tuple(Number{}, Number{}))), @@ -127,8 +177,8 @@ struct ABTransferWaveTiles ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1, make_tuple( make_pass_through_transform( - math::integer_divide_ceil(sizeMN, Number{})), - make_pass_through_transform(math::integer_divide_ceil(sizeK, Number{})), + math::integer_divide_ceil(MNPad, Number{})), + make_pass_through_transform(math::integer_divide_ceil(KPad, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), make_freeze_transform(I0)), @@ -143,9 +193,9 @@ struct ABTransferWaveTiles transform_tensor_descriptor( ab_grid_desc_mntiles_ktiles, make_tuple(make_pass_through_transform( - math::integer_divide_ceil(sizeMN, Number{})), + math::integer_divide_ceil(MNPad, Number{})), make_pass_through_transform( - math::integer_divide_ceil(sizeK, Number{})), + math::integer_divide_ceil(KPad, Number{})), make_unmerge_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{})), @@ -157,8 +207,8 @@ struct ABTransferWaveTiles ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1, make_tuple( make_pass_through_transform( - math::integer_divide_ceil(sizeMN, Number{})), - make_pass_through_transform(math::integer_divide_ceil(sizeK, Number{})), + math::integer_divide_ceil(MNPad, Number{})), + make_pass_through_transform(math::integer_divide_ceil(KPad, Number{})), make_pass_through_transform(Number{}), make_freeze_transform(I0), make_pass_through_transform(Number{})), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp index 9404271dad2..5829e490d6f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp @@ -5,6 +5,7 @@ #include "ck/utility/amd_address_space.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp" #include "ck/utility/math.hpp" namespace ck { @@ -19,94 +20,78 @@ template -struct ABTransferWaveTilesInterleave + index_t MNWaves_Gemm> +struct 
ABTransferWaveTilesInterleave : ABTransferWaveTiles { - __device__ static constexpr bool IsLDSNeeded() { return true; } + using Base = ABTransferWaveTiles; + + using Base::ABDoTranspose; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::MNKRow; + + using Base::GetBlockLaneIdx; + using Base::GetBlockStep; + using Base::GetGridLaneIdx; + using Base::GetWaveIdx; + using Base::PadGridDescriptor; + using typename Base::ThisThreadBlock; - static_assert(!(is_same_v, pk_i4_t>), - "wave tile transfer method does not support pk_i4_t"); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; static constexpr auto I4 = Number<4>{}; - static constexpr index_t MNKRow = 2; - - using ThisThreadBlock = ThisThreadBlock; - - static constexpr bool ABDoTranspose = !is_same_v; static_assert(!ABDoTranspose, "wave tile interleaved transfer does not support transpose yet"); - // Tiles distribution for global memory loading - // Notes: support for not power of 2 needs to be reviewed later on - // The tiles are distributed along the non-contiguous matrix dimension - // Example 4 waves A row-major MPerBlock = 64, KPerBlock = 64 - // MRepeat = 1, KRepeat = 4 - // ------------- - // |W0| | | | - // ------------- - // |W1| | | | - // ------------- - // |W2| | | | - // ------------- - // |W3| | | | - // ------------- - // Example 4 waves A column-major MPerBlock = 64, KPerBlock = 64 - // MRepeat = 4, KRepeat = 1 - // ------------- - // |W0|W1|W2|W3| - // ------------- - // | | | | | - // ------------- - // | | | | | - // ------------- - // | | | | | - // ------------- - static constexpr index_t NumberOfWaves = BlockSize / WaveSize; - static constexpr index_t MNMajorWaves_ = - MNPerBlock / MNPerWmma % std::min(MNPerBlock / MNPerWmma, NumberOfWaves) == 0 - ? std::min(MNPerBlock / MNPerWmma, NumberOfWaves) - : (MNPerBlock / MNPerWmma % 2 == 0 ? 2 : 1); - static constexpr index_t KMajorWaves_ = - KPerBlock / KPack % std::min(KPerBlock / KPack, NumberOfWaves) == 0 - ? std::min(KPerBlock / KPack, NumberOfWaves) - : (KPerBlock / KPack % 2 == 0 ? 2 : 1); + using Base::KRepeat_; + using Base::KWaves_; + using Base::MNRepeat_; - static constexpr index_t MNWaves_Load = - ABDoTranspose ? NumberOfWaves / KMajorWaves_ : MNMajorWaves_; - static constexpr index_t KWaves_Load = - ABDoTranspose ? 
KMajorWaves_ : NumberOfWaves / MNMajorWaves_; - static constexpr index_t KRepeat_Load = KPerBlock / (KWaves_Load * KPack); - static constexpr index_t MNRepeat_Load = MNPerBlock / (MNWaves_Load * MNPerWmma); - - static constexpr index_t MNWaves_ = MNWaves_User; - static constexpr index_t KWaves_ = (BlockSize / WaveSize) / MNWaves_User; - static constexpr index_t KRepeat_ = KPerBlock / (KWaves_ * KPack); - static constexpr index_t MNRepeat_ = MNPerBlock / (MNWaves_ * MNPerWmma); + static constexpr index_t MNWaves_Grid = MNWaves_Gemm; + static constexpr index_t KWaves_Grid = (BlockSize / WaveSize) / MNWaves_Gemm; + static constexpr index_t KRepeat_Grid = KPerBlock / (KWaves_Grid * KPack); + static constexpr index_t MNRepeat_Grid = MNPerBlock / (MNWaves_Grid * MNPerWmma); template __host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc, index_t sizeMN, - index_t, + index_t MNPad, index_t sizeK, - index_t, + index_t KPad, index_t, index_t) { // Notes: padding is currently not supported - static_assert(!PadMN && !PadK, "padding is currently not supported"); + // static_assert(!PadMN && !PadK, "padding is currently not supported"); + const auto base_desc_padded = Base::template PadGridDescriptor( + base_desc, sizeMN, MNPad, sizeK, KPad, 0, 0); // Divide the base descriptor MN_K into tiles const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor( - base_desc, + base_desc_padded, make_tuple(make_unmerge_transform(make_tuple( - math::integer_divide_ceil(sizeMN, Number{}), - Number{})), + math::integer_divide_ceil(MNPad, Number{}), + Number{})), make_unmerge_transform(make_tuple( - math::integer_divide_ceil(sizeK, Number{}), Number{}))), + math::integer_divide_ceil(KPad, Number{}), Number{}))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); @@ -122,9 +107,10 @@ struct ABTransferWaveTilesInterleave ab_grid_desc_mntiles_ktiles, make_tuple( make_pass_through_transform( - math::integer_divide_ceil(sizeMN, Number{})), - make_pass_through_transform(math::integer_divide_ceil(sizeK, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), + math::integer_divide_ceil(MNPad, Number{})), + make_pass_through_transform(math::integer_divide_ceil(KPad, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{})), make_pass_through_transform(Number{})), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<3, 2>{}, Sequence<4>{})); @@ -133,10 +119,10 @@ struct ABTransferWaveTilesInterleave transform_tensor_descriptor( ab_grid_desc_mntiles_ktiles_mnrepeat, make_tuple(make_pass_through_transform(math::integer_divide_ceil( - sizeMN, Number{})), + MNPad, Number{})), make_pass_through_transform( - math::integer_divide_ceil(sizeK, Number{})), - make_pass_through_transform(Number{}), + math::integer_divide_ceil(KPad, Number{})), + make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), make_unmerge_transform( make_tuple(Number{}, Number{}))), @@ -154,9 +140,9 @@ struct ABTransferWaveTilesInterleave ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1, make_tuple( make_pass_through_transform( - math::integer_divide_ceil(sizeMN, Number{})), - make_pass_through_transform(math::integer_divide_ceil(sizeK, Number{})), - make_pass_through_transform(Number{}), + math::integer_divide_ceil(MNPad, Number{})), + make_pass_through_transform(math::integer_divide_ceil(KPad, Number{})), + make_pass_through_transform(Number{}), 
make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), make_freeze_transform(I0)), @@ -173,10 +159,6 @@ struct ABTransferWaveTilesInterleave Sequence<4>{}, Sequence<>{})); } - else - { - // TODO - } } __device__ static constexpr auto GetBlockDescriptor() @@ -187,15 +169,15 @@ struct ABTransferWaveTilesInterleave // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize const auto a_grid_desc_mraw_kraw = [&]() { return make_naive_tensor_descriptor( - make_tuple(Number{}, - Number{}, - Number{}, + make_tuple(Number{}, + Number{}, + Number{}, Number{}, Number{}, Number{}), - make_tuple(Number{}, + make_tuple(Number{}, Number{}, - Number{}, + Number{}, Number{}, Number{}, I1)); @@ -204,9 +186,9 @@ struct ABTransferWaveTilesInterleave // Freeze VectorSize to first element of the chunk (for convenience) return transform_tensor_descriptor( a_grid_desc_mraw_kraw, - make_tuple(make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), + make_tuple(make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), make_freeze_transform(I0)), @@ -224,63 +206,6 @@ struct ABTransferWaveTilesInterleave Sequence<>{})); } - __device__ static auto GetWaveIdx() - { - const index_t thread_id = ThisThreadBlock::GetThreadId(); - - constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(MNWaves_Load, KWaves_Load, WaveSize))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); - } - - __device__ static auto GetBlockLaneIdx() - { - const index_t lane_id = __lane_id(); - - constexpr index_t LanesPerSubTile = ABDoTranspose ? KPack : MNPerWmma; - - constexpr auto laneid_to_block_lane_idx_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(MNKRow, LanesPerSubTile))), - make_tuple(Sequence<0, 1>{}), - make_tuple(Sequence<0>{})); - - return laneid_to_block_lane_idx_adaptor.CalculateBottomIndex(make_multi_index(lane_id)); - } - - template - __device__ static auto GetGridLaneIdx() - { - const index_t lane_id = __lane_id(); - - constexpr index_t SubTilesRow = MNKRow; - constexpr index_t SubTilesCol = 4 / sizeof(ABDataType); - constexpr index_t LanesPerSubTile = - ABDoTranspose ? KPack / SubTilesCol : MNPerWmma / SubTilesCol; - constexpr auto dims_tuple = ABDoTranspose - ? 
make_tuple(SubTilesCol, SubTilesRow, LanesPerSubTile) - : make_tuple(SubTilesRow, SubTilesCol, LanesPerSubTile); - - constexpr auto laneid_to_grid_lane_idx_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(dims_tuple)), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto indices = - laneid_to_grid_lane_idx_adaptor.CalculateBottomIndex(make_multi_index(lane_id)); - - if constexpr(!ABDoTranspose) - { - return make_multi_index(indices[I0], indices[I1] * LanesPerSubTile + indices[I2]); - } - else - { - return make_multi_index(indices[I1], indices[I0] * LanesPerSubTile + indices[I2]); - } - } - template (); + const auto grid_lane_id = Base::template GetGridLaneIdx(); index_t lane_group_grid = grid_lane_id[I0]; index_t lane_local_id_grid = grid_lane_id[I1]; @@ -313,70 +238,36 @@ struct ABTransferWaveTilesInterleave index_t lane_group_block = block_lane_id[I0]; index_t lane_local_id_block = block_lane_id[I1]; - constexpr index_t MNRepeatRatio = MNRepeat_ / MNRepeat_Load; + constexpr index_t MNRepeatRatio = MNRepeat_Grid / MNRepeat_; return ThreadGroupTransferGlobal, - Sequence, + Sequence, + Sequence, Sequence, ABK1Value, ABDoTranspose>( grid_descriptor[I0], block_descriptor, - make_multi_index(block_mn_id * MNWaves_ + wave_idMN / MNRepeatRatio, - wave_idK * KRepeat_, - (wave_idMN % MNRepeatRatio) * MNRepeat_Load, + make_multi_index(block_mn_id * MNWaves_Grid + wave_idMN / MNRepeatRatio, + wave_idK * KRepeat_Grid, + (wave_idMN % MNRepeatRatio) * MNRepeat_, lane_group_grid, lane_local_id_grid), make_multi_index(wave_idMN / MNRepeatRatio, - wave_idK * KRepeat_Load, - (wave_idMN % MNRepeatRatio) * MNRepeat_Load, + wave_idK * KRepeat_, + (wave_idMN % MNRepeatRatio) * MNRepeat_, lane_group_block, lane_local_id_block), ab_element_op); } - template - __host__ __device__ static constexpr auto MakeWmmaTileDescriptor() - { - // This is a block descriptor used to read LDS memory into register - // It's defined in a way consistent with the existing implementation to - // avoid changes in the pipelines - return make_naive_tensor_descriptor(make_tuple(I1, - Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - Number{}), - make_tuple(I0, - Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - I1)); - } - __device__ static constexpr auto GetBlockStep() { // Grid descriptor step (MoveSrcSliceWindow) - return make_multi_index(I0, KWaves_Load * KRepeat_Load, I0, I0, I0); - } - - template - __device__ static constexpr index_t GetKDimension(const GridDescriptor& grid_desc) - { - return grid_desc.GetLength(I1) * KPack; - } - - template - __device__ static auto GetBuffer(LDSType* p_shared_AB, const IndexType& size) - { - return make_dynamic_buffer(p_shared_AB, size); + return make_multi_index(I0, KWaves_ * KRepeat_, I0, I0, I0); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 8dc1c26d086..7f9d7ca8fd6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -265,7 +265,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // We need to investigate if it makes sense to remove cshuffle for smaller types static constexpr bool UseDirectStore = is_same_v && sizeof(ComputeTypeB) == 2 && sizeof(EDataType) == 2 && - NumDTensor == 0; + NumDTensor == 0 && (NRepeat % 2) == 0; #else static constexpr bool 
IsAWaveTransferApplicable = false; static constexpr bool IsBWaveTransferApplicable = false; From e889ff14ee61ef599921c6c8ecbbff71691d2963 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Tue, 16 Dec 2025 11:13:50 +0000 Subject: [PATCH 04/14] Fix convolution fwd --- ...conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp | 85 +++++++++---------- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 4 +- 2 files changed, 42 insertions(+), 47 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp index edba3749999..278b98fa32f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp @@ -48,8 +48,8 @@ namespace { * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). */ template {}; template - static auto - MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) - + static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< @@ -324,8 +322,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 } template - static auto - MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< @@ -621,10 +618,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 I1>; // desc for blockwise copy - using AGridDesc_AK0_M_AK1 = remove_cvref_t( - dummy_conv_to_gemm_transformer))>; - using BGridDesc_BK0_N_BK1 = remove_cvref_t( - dummy_conv_to_gemm_transformer))>; + using AGridDesc_M_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using BGridDesc_N_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; // Argument struct Argument : public BaseArgument @@ -690,10 +687,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 ds_grid_desc_m_n_{}, e_grid_desc_m_n_{ DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, - a_grid_desc_ak0_m_ak1_{ - MakeAGridDescriptor_AK0_M_AK1(conv_to_gemm_transformer_)}, - b_grid_desc_bk0_n_bk1_{ - MakeBGridDescriptor_BK0_N_BK1(conv_to_gemm_transformer_)}, + a_grid_desc_m_k_{MakeAGridDescriptor_M_K(conv_to_gemm_transformer_)}, + b_grid_desc_n_k_{MakeBGridDescriptor_N_K(conv_to_gemm_transformer_)}, compute_ptr_offset_of_groups_{}, compute_ptr_offset_of_n_{}, a_element_op_{a_element_op}, @@ -793,8 +788,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 } { - const index_t GemmM = a_grid_desc_ak0_m_ak1_.GetLength(I0); - const index_t GemmN = b_grid_desc_bk0_n_bk1_.GetLength(I0); + const index_t GemmM = a_grid_desc_m_k_.GetLength(I0); + const index_t GemmN = b_grid_desc_n_k_.GetLength(I0); const auto MBlock = CTranspose ? GridwiseGemmCTranspose::CalculateMBlock(GemmN) : GridwiseGemmCTranspose::CalculateMBlock(GemmM); const auto NBlock = CTranspose ? 
GridwiseGemmCTranspose::CalculateNBlock(GemmM) @@ -878,7 +873,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 is_same_v) { size_as_buffers[i] = - (a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() + + (a_grid_desc_m_k_.GetElementSpaceSize() + (num_group_ - NumGroupsToMerge) * (a_g_n_c_wis_strides_[0])) * sizeof(ADataType_single) / GridwiseGemm::APackedSize; } @@ -886,13 +881,13 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 { if(CTranspose && a_g_n_c_wis_lengths_[I1] > 1) { - size_as_buffers[i] = (a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() + + size_as_buffers[i] = (a_grid_desc_m_k_.GetElementSpaceSize() + (eff_num_group - 1) * (a_g_n_c_wis_strides_[0])) * sizeof(ADataType_single) / GridwiseGemm::APackedSize; } else { - size_as_buffers[i] = a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() * + size_as_buffers[i] = a_grid_desc_m_k_.GetElementSpaceSize() * eff_num_group * sizeof(ADataType_single) / GridwiseGemm::APackedSize; } @@ -909,7 +904,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 static_for<0, NumBTensor, 1>{}([&](auto i) { using BDataType_single = remove_cvref_t>; - size_bs_buffers[i] = b_grid_desc_bk0_n_bk1_.GetElementSpaceSize() * eff_num_group * + size_bs_buffers[i] = b_grid_desc_n_k_.GetElementSpaceSize() * eff_num_group * sizeof(BDataType_single) / GridwiseGemm::BPackedSize; }); @@ -956,8 +951,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 void Print() const { - std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl; - std::cout << "B[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl; + std::cout << "A[AK0, M, AK1]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[BK0, N, BK1]: " << b_grid_desc_n_k_ << std::endl; static_for<0, NumDTensor, 1>{}( [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; @@ -993,8 +988,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 DsGridDesc_M_N ds_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_; - AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_; EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; @@ -1043,9 +1038,9 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 constexpr index_t minimum_occupancy = BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 
1 : 2; - const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0); - const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I0); - const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); + const index_t GemmM = arg.a_grid_desc_m_k_.GetLength(I0); + const index_t GemmN = arg.b_grid_desc_n_k_.GetLength(I0); + const index_t GemmK = arg.a_grid_desc_m_k_.GetLength(I1); const index_t num_workgroups_per_Conv_N = arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; @@ -1187,8 +1182,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 dim3(BlockSize), 0, gemm_arg_, - arg.b_grid_desc_bk0_n_bk1_, - arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_n_k_, + arg.a_grid_desc_m_k_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_groups_, @@ -1204,8 +1199,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 dim3(BlockSize), 0, gemm_arg, - arg.b_grid_desc_bk0_n_bk1_, - arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_n_k_, + arg.a_grid_desc_m_k_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_groups_, @@ -1285,8 +1280,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 dim3(BlockSize), 0, gemm_arg_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, + arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_groups_, @@ -1302,8 +1297,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 dim3(BlockSize), 0, gemm_arg, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, + arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_groups_, @@ -1321,8 +1316,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3< GridwiseGemmCTranspose, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_N_K, + DeviceOp::AGridDesc_M_K, DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, ComputePtrOffset, @@ -1336,8 +1331,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3< GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::AGridDesc_M_K, + DeviceOp::BGridDesc_N_K, DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, ComputePtrOffset, @@ -1979,9 +1974,9 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 } // check Gridwise GEMM - const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0); - const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I0); - const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); + const index_t GemmM = arg.a_grid_desc_m_k_.GetLength(I0); + const index_t GemmN = arg.b_grid_desc_n_k_.GetLength(I0); + const index_t GemmK = arg.a_grid_desc_m_k_.GetLength(I1); if constexpr(CTranspose) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 7f9d7ca8fd6..e517a036b89 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -531,7 +531,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base constexpr bool padM = false; constexpr bool padK = false; - return ATransfer::template MakeGridDescriptor(base_desc, M, 0, K, 0, 0, AK0); + return ATransfer::template MakeGridDescriptor(base_desc, M, M, K, K, 0, AK0); } __host__ __device__ static auto @@ -570,7 +570,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base constexpr bool padN = false; constexpr bool padK = false; - return BTransfer::template MakeGridDescriptor(base_desc, N, 0, K, 0, 0, BK0); + return BTransfer::template MakeGridDescriptor(base_desc, N, N, K, K, 0, BK0); } __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor() From b6d5215273f06b50b3e51167f36f8abf9e708e89 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Tue, 16 Dec 2025 13:21:43 +0000 Subject: [PATCH 05/14] Add parameter to force thread tile to device struct --- ...grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp index 278b98fa32f..0cdb09e79dc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp @@ -200,6 +200,7 @@ template ::value, Number<0>, @@ -444,10 +445,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 BlkGemmPipelineVer, AComputeDataType, BComputeDataType, - false, // PermuteA - false, // PermuteB - false, // IsBPreShuffled - false>; // ForceThreadTileTransfer + false, // PermuteA + false, // PermuteB + false, // IsBPreShuffled + UseThreadTileTransfer>; // ForceThreadTileTransfer // TODO: Previously available template param DoElementwiseBeforeCShuffle! 
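[Editorial sketch, not part of the patch] The new UseThreadTileTransfer device-level parameter only forwards a compile-time flag into the gridwise GEMM, where it is consumed as ForceThreadTileTransfer when the transfer policy is chosen. A minimal standalone sketch of that selection pattern is shown below; the struct names and the Applicable flag are placeholders for illustration, not CK's real templates.

    // Sketch of the compile-time transfer-policy selection (placeholder types).
    #include <type_traits>

    struct ThreadTileTransfer {};
    struct WaveTileTransfer {};

    template <bool ForceThreadTileTransfer, bool IsWaveTransferApplicable>
    struct GridwiseGemmSketch
    {
        // Wave-tile loading is used only when it is applicable for the problem
        // and the device layer did not force the thread-tile path.
        static constexpr bool UseWaveTransfer =
            IsWaveTransferApplicable && !ForceThreadTileTransfer;

        using ATransfer = typename std::conditional<UseWaveTransfer,
                                                    WaveTileTransfer,
                                                    ThreadTileTransfer>::type;
    };

    // The device layer simply forwards its template parameter.
    template <bool UseThreadTileTransfer>
    struct DeviceOpSketch
    {
        using Gemm = GridwiseGemmSketch<UseThreadTileTransfer, /*Applicable=*/true>;
    };

    static_assert(std::is_same_v<DeviceOpSketch<false>::Gemm::ATransfer, WaveTileTransfer>, "");
    static_assert(std::is_same_v<DeviceOpSketch<true>::Gemm::ATransfer, ThreadTileTransfer>, "");

    int main() { return 0; }

With this wiring, setting UseThreadTileTransfer at the device struct disables the wave-tile path without touching the gridwise implementation itself.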
@@ -521,7 +522,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 false, // PermuteB false, // PermuteA false, // IsBPreShuffled - true>; // ForceThreadTileTransfer + true>; // ForceThreadTileTransfer (always force it because of limitations in the transfer) using GridwiseGemmCTranspose = std::conditional_t; From 7993bc5e4d774b3952a048b1591ccb8166ac96d1 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Thu, 18 Dec 2025 09:20:36 +0000 Subject: [PATCH 06/14] Add fast instances using wave transfer and direct store --- ...conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp | 14 +++- .../grid/gridwise_ab_transfer_wave_tiles.hpp | 25 +++--- ...wise_ab_transfer_wave_tiles_interleave.hpp | 21 ++--- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 4 +- ...wmma_cshufflev3_wave_transfer_instance.hpp | 76 +++++++++++++++++++ .../gpu/grouped_convolution_forward.hpp | 4 + ...ed_convolution_forward_wmma_cshufflev3.inc | 28 +++++++ .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 2 + ...ansfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 51 +++++++++++++ ...ransfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 51 +++++++++++++ 10 files changed, 250 insertions(+), 26 deletions(-) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instance.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp index 0cdb09e79dc..ee05c7c6a42 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp @@ -357,15 +357,21 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 const auto out_gemmmraw_gemmnraw_desc = conv_to_gemm_transformer.template MakeCDescriptor_M_N(); + // Force MN padding on the output tensor. This allows to use Gemm default or only K padding + // and remove some instructions in the hot loop (same approach used for gemm universal). 
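[Editorial sketch, not part of the patch] As a concrete illustration of the MN padding described in the comment above, the output extents are rounded up to multiples of the block tile so that the A/B side can stay on the Default or K-only padding specialization. The numbers below are invented for the example and are not taken from any instance.

    // Toy illustration of rounding the raw GEMM output extents up to the block tile.
    #include <cstdio>

    int main()
    {
        const int MRaw = 1000, NRaw = 192;         // raw GEMM output extents (made up)
        const int MPerBlock = 128, NPerBlock = 64; // block tile sizes (made up)

        // Round each extent up to a multiple of its block tile, as MN padding does.
        const int MPad = ((MRaw + MPerBlock - 1) / MPerBlock) * MPerBlock; // 1024
        const int NPad = ((NRaw + NPerBlock - 1) / NPerBlock) * NPerBlock; // 192 (already aligned)

        // With the output descriptor always MN-padded, the hot loop needs no extra
        // bounds handling for M and N, only (optionally) for K.
        std::printf("padded output: %d x %d\n", MPad, NPad);
        return 0;
    }
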
if constexpr(CTranspose) { - constexpr auto matrix_padder_trans = - MatrixPadder{NPerBlock, MPerBlock, KPerBlock}; - return matrix_padder_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + constexpr auto matrix_padder_MN_padding_trans = + MatrixPadder{ + NPerBlock, MPerBlock, KPerBlock}; + return matrix_padder_MN_padding_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); } else { - return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + constexpr auto matrix_padder_MN_padding = + MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + return matrix_padder_MN_padding.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp index 867244c2518..7c3c4f609b8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp @@ -136,6 +136,9 @@ struct ABTransferWaveTiles static_assert(!((PadMN || PadK) && ABDoTranspose), "padding is currently not supported with transpose"); + const index_t MN_grid = PadMN ? sizeMN : MNPad; + const index_t K_grid = PadK ? sizeK : KPad; + const auto base_desc_padded = PadGridDescriptor(base_desc, sizeMN, MNPad, sizeK, KPad, 0, 0); @@ -144,9 +147,9 @@ struct ABTransferWaveTiles base_desc_padded, make_tuple( make_unmerge_transform(make_tuple( - math::integer_divide_ceil(MNPad, Number{}), Number{})), - make_unmerge_transform( - make_tuple(math::integer_divide_ceil(KPad, Number{}), Number{}))), + math::integer_divide_ceil(MN_grid, Number{}), Number{})), + make_unmerge_transform(make_tuple( + math::integer_divide_ceil(K_grid, Number{}), Number{}))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); @@ -162,9 +165,9 @@ struct ABTransferWaveTiles transform_tensor_descriptor( ab_grid_desc_mntiles_ktiles, make_tuple(make_pass_through_transform( - math::integer_divide_ceil(MNPad, Number{})), + math::integer_divide_ceil(MN_grid, Number{})), make_pass_through_transform( - math::integer_divide_ceil(KPad, Number{})), + math::integer_divide_ceil(K_grid, Number{})), make_pass_through_transform(Number{}), make_unmerge_transform( make_tuple(Number{}, Number{}))), @@ -177,8 +180,8 @@ struct ABTransferWaveTiles ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1, make_tuple( make_pass_through_transform( - math::integer_divide_ceil(MNPad, Number{})), - make_pass_through_transform(math::integer_divide_ceil(KPad, Number{})), + math::integer_divide_ceil(MN_grid, Number{})), + make_pass_through_transform(math::integer_divide_ceil(K_grid, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), make_freeze_transform(I0)), @@ -193,9 +196,9 @@ struct ABTransferWaveTiles transform_tensor_descriptor( ab_grid_desc_mntiles_ktiles, make_tuple(make_pass_through_transform( - math::integer_divide_ceil(MNPad, Number{})), + math::integer_divide_ceil(MN_grid, Number{})), make_pass_through_transform( - math::integer_divide_ceil(KPad, Number{})), + math::integer_divide_ceil(K_grid, Number{})), make_unmerge_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{})), @@ -207,8 +210,8 @@ struct ABTransferWaveTiles ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1, make_tuple( make_pass_through_transform( - math::integer_divide_ceil(MNPad, Number{})), - make_pass_through_transform(math::integer_divide_ceil(KPad, Number{})), + 
math::integer_divide_ceil(MN_grid, Number{})), + make_pass_through_transform(math::integer_divide_ceil(K_grid, Number{})), make_pass_through_transform(Number{}), make_freeze_transform(I0), make_pass_through_transform(Number{})), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp index 5829e490d6f..c31a664e3a0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp @@ -79,19 +79,20 @@ struct ABTransferWaveTilesInterleave : ABTransferWaveTiles( base_desc, sizeMN, MNPad, sizeK, KPad, 0, 0); + const index_t MN_grid = PadMN ? sizeMN : MNPad; + const index_t K_grid = PadK ? sizeK : KPad; + // Divide the base descriptor MN_K into tiles const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor( base_desc_padded, make_tuple(make_unmerge_transform(make_tuple( - math::integer_divide_ceil(MNPad, Number{}), + math::integer_divide_ceil(MN_grid, Number{}), Number{})), make_unmerge_transform(make_tuple( - math::integer_divide_ceil(KPad, Number{}), Number{}))), + math::integer_divide_ceil(K_grid, Number{}), Number{}))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); @@ -107,8 +108,8 @@ struct ABTransferWaveTilesInterleave : ABTransferWaveTiles{})), - make_pass_through_transform(math::integer_divide_ceil(KPad, Number{})), + math::integer_divide_ceil(MN_grid, Number{})), + make_pass_through_transform(math::integer_divide_ceil(K_grid, Number{})), make_unmerge_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{})), @@ -119,9 +120,9 @@ struct ABTransferWaveTilesInterleave : ABTransferWaveTiles{})), + MN_grid, Number{})), make_pass_through_transform( - math::integer_divide_ceil(KPad, Number{})), + math::integer_divide_ceil(K_grid, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), make_unmerge_transform( @@ -140,8 +141,8 @@ struct ABTransferWaveTilesInterleave : ABTransferWaveTiles{})), - make_pass_through_transform(math::integer_divide_ceil(KPad, Number{})), + math::integer_divide_ceil(MN_grid, Number{})), + make_pass_through_transform(math::integer_divide_ceil(K_grid, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index e517a036b89..b549e03a714 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -263,9 +263,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_base BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; // We need to investigate if it makes sense to remove cshuffle for smaller types + // Currently we use direct store for NRepeat equal to 4 or 8. 
For 16 bit type we use at + // lease buffer store 64 bit for 16 contiguous threads -> 128 bytes in toral (full cache line) static constexpr bool UseDirectStore = is_same_v && sizeof(ComputeTypeB) == 2 && sizeof(EDataType) == 2 && - NumDTensor == 0 && (NRepeat % 2) == 0; + NumDTensor == 0 && (NRepeat == 4 || NRepeat == 8); #else static constexpr bool IsAWaveTransferApplicable = false; static constexpr bool IsBWaveTransferApplicable = false; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instance.hpp new file mode 100644 index 00000000000..2529c55e31b --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instance.hpp @@ -0,0 +1,76 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddClamp = ck::tensor_operation::element_wise::AddClamp; +using Clamp = ck::tensor_operation::element_wise::Clamp; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmDefault = GemmSpecialization::Default; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index d38aa66ece0..08e2092c501 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -797,6 +797,8 @@ struct DeviceOperationInstanceFactory>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -117,6 +131,20 @@ void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instanc PassThrough, PassThrough, PassThrough>>>& instances); + +void 
add_device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); #endif // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 380c83fa929..4b8f1d1a160 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -125,6 +125,8 @@ set(GROUPED_CONV2D_FWD wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part4.cpp wmma/large_tensor/device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp wmma/large_tensor/device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp ) # Add generated files for sharded instantiations. include(ShardInstantiation) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 00000000000..cbb4eae126c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,51 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + GemmMNKPadding, + BF16>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + GemmDefault, + BF16>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 00000000000..099804294d7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,51 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_wave_transfer_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + GemmMNKPadding, + F16>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_wave_transfer_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + GemmDefault, + F16>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From c1eb985d34feea65fee8e8e52c75f6c2982b9660 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 19 Dec 2025 10:44:33 +0000 Subject: [PATCH 07/14] Fix bug --- .../gpu/grid/gridwise_ab_transfer_wave_tiles.hpp | 4 ++-- .../gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp index 7c3c4f609b8..279c08d4d40 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp @@ -136,8 +136,8 @@ struct ABTransferWaveTiles static_assert(!((PadMN || PadK) && ABDoTranspose), "padding is currently not supported with transpose"); - const index_t MN_grid = PadMN ? sizeMN : MNPad; - const index_t K_grid = PadK ? sizeK : KPad; + const index_t MN_grid = !PadMN ? sizeMN : MNPad; + const index_t K_grid = !PadK ? sizeK : KPad; const auto base_desc_padded = PadGridDescriptor(base_desc, sizeMN, MNPad, sizeK, KPad, 0, 0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp index c31a664e3a0..bfe5b7bd08a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp @@ -82,8 +82,8 @@ struct ABTransferWaveTilesInterleave : ABTransferWaveTiles( base_desc, sizeMN, MNPad, sizeK, KPad, 0, 0); - const index_t MN_grid = PadMN ? sizeMN : MNPad; - const index_t K_grid = PadK ? sizeK : KPad; + const index_t MN_grid = !PadMN ? sizeMN : MNPad; + const index_t K_grid = !PadK ? 
sizeK : KPad; // Divide the base descriptor MN_K into tiles const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor( From f0f4f1c2357015bd8f0a713112782274aaec141e Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 19 Dec 2025 11:01:22 +0000 Subject: [PATCH 08/14] Restore example --- example/01_gemm/gemm_wmma_fp16_v3.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/example/01_gemm/gemm_wmma_fp16_v3.cpp b/example/01_gemm/gemm_wmma_fp16_v3.cpp index 2ad704cda8b..5b10edd681a 100644 --- a/example/01_gemm/gemm_wmma_fp16_v3.cpp +++ b/example/01_gemm/gemm_wmma_fp16_v3.cpp @@ -11,8 +11,8 @@ using AccDataType = float; using CShuffleDataType = ck::half_t; using CDataType = ck::half_t; -using ALayout = Row; -using BLayout = Col; +using ALayout = Col; +using BLayout = Row; using CLayout = Row; using AElementOp = PassThrough; @@ -31,10 +31,10 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf 8, 8, 16, 16, 2, 8, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, - 2, 8, 8, 1, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, - 2, 8, 8, 1, + S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, 1, + S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>; From ad816e09fcc0f3dcbd25c8ed8dd56650b74b5d1d Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 19 Dec 2025 13:07:19 +0000 Subject: [PATCH 09/14] Fused kernels can not use direct store yet Need to add template parameter. It will be removed during refactoring --- ...device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp | 5 ++++- .../device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp | 5 ++++- .../gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 9 ++++++--- .../gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp | 6 ++++-- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp index d35f22ba4ad..f0216c3f711 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp @@ -273,7 +273,10 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3 ComputeTypeA, ComputeTypeB, PermuteA, - PermuteB>; + PermuteB, + false, + false, + true>; // Welford 2nd part kernel template diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp index b64b72f4d4c..317c4073df9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp @@ -187,7 +187,10 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera ComputeTypeA, ComputeTypeB, PermuteA, - PermuteB>; + PermuteB, + false, + false, + true>; using ReduceTrait = ReduceTrait_ + bool ForceThreadTileTransfer = false, + bool IsFusedKernel = false> struct GridwiseGemm_wmma_cshuffle_v3 : GridwiseGemm_wmma_cshuffle_v3_base< ALayout, @@ -231,7 +232,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 PermuteA, PermuteB, IsBPreShuffled, - ForceThreadTileTransfer> + ForceThreadTileTransfer, + IsFusedKernel> { using Base = GridwiseGemm_wmma_cshuffle_v3_base< 
ALayout, @@ -285,7 +287,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 PermuteA, PermuteB, IsBPreShuffled, - ForceThreadTileTransfer>; + ForceThreadTileTransfer, + IsFusedKernel>; using Base::I0; using Base::I1; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index b549e03a714..f1f498abf46 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -175,7 +175,8 @@ template // only needed for convolution (limitation) + bool ForceThreadTileTransfer = false, // only needed for convolution (limitation) + bool IsFusedKernel = false> struct GridwiseGemm_wmma_cshuffle_v3_base { @@ -267,7 +268,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // lease buffer store 64 bit for 16 contiguous threads -> 128 bytes in toral (full cache line) static constexpr bool UseDirectStore = is_same_v && sizeof(ComputeTypeB) == 2 && sizeof(EDataType) == 2 && - NumDTensor == 0 && (NRepeat == 4 || NRepeat == 8); + NumDTensor == 0 && (NRepeat == 4 || NRepeat == 8) && + !IsFusedKernel; #else static constexpr bool IsAWaveTransferApplicable = false; static constexpr bool IsBWaveTransferApplicable = false; From 8d68c5596ae924f5eab98f7c77ce46718adf7f14 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 19 Dec 2025 15:14:43 +0000 Subject: [PATCH 10/14] Fix for grouped gemm and add missing condition --- .../device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp | 12 ++++++++---- .../grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp | 12 ++++++++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp index 714d5670204..39024d39e43 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp @@ -41,8 +41,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const index_t group_count) { #if(defined(__gfx11__) || defined(__gfx12__)) - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + using EpilogueType = typename std::conditional::type; + + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; const index_t block_id = get_block_1d_id(); @@ -89,13 +93,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]); - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; GridwiseGemm::template Run(static_cast(p_shared), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index f1f498abf46..c2c6fd776ba 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -242,6 +242,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return 1; }(); + static constexpr index_t WaveSize = + WmmaSelector::selected_wmma + .wave_size; + // 
Limitations of the current implementation: // - no multiAB // - GemmSpecialization Default @@ -263,22 +267,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_base is_same_v) && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; + static constexpr bool IsWaveTileInterleavedFitting = + (NPerBlock / NPerWmma / NRepeat) * (KPerBlock / KPack) >= (BlockSize / WaveSize); + // We need to investigate if it makes sense to remove cshuffle for smaller types // Currently we use direct store for NRepeat equal to 4 or 8. For 16 bit type we use at // lease buffer store 64 bit for 16 contiguous threads -> 128 bytes in toral (full cache line) static constexpr bool UseDirectStore = is_same_v && sizeof(ComputeTypeB) == 2 && sizeof(EDataType) == 2 && NumDTensor == 0 && (NRepeat == 4 || NRepeat == 8) && - !IsFusedKernel; + !IsFusedKernel && IsWaveTileInterleavedFitting; #else static constexpr bool IsAWaveTransferApplicable = false; static constexpr bool IsBWaveTransferApplicable = false; static constexpr bool UseDirectStore = false; #endif - static constexpr index_t WaveSize = - WmmaSelector::selected_wmma - .wave_size; static constexpr bool UseBlockPaddingA = ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4; using ATransfer = typename std::conditional< From 0ecbef71f965f3909bf3e1f00ce2f63d382303e0 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 19 Dec 2025 15:44:44 +0000 Subject: [PATCH 11/14] Fix batched gemm --- .../impl/device_batched_gemm_wmma_cshuffle_v3.hpp | 12 +++++++++--- .../device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp | 12 +++++++++--- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp index 11e2add1324..a18f108e473 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp @@ -60,8 +60,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const long_index_t c_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + using EpilogueType = + typename std::conditional::type; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -84,7 +90,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; }); - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; GridwiseGemm::template Run( p_as_grid_shift, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp index ee1ddc494d8..b88f071a962 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -46,8 +46,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) std::is_same_v))) { #endif - constexpr index_t LDS_size = GridwiseGemm::template 
GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + using EpilogueType = + typename std::conditional::type; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); // The normal approach to batching would be to increase the grid size by just stretching out // the grid Z dimension (which is the outermost dimension), but this depends on lower level // functions not directly using the Z dimension for other calculations. As it turns out, k @@ -86,7 +92,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; }); - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; GridwiseGemm::template Run( p_as_grid_shift, From 1c87ebd3c9881d82c7cc8d3f7276e70e95813087 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Tue, 6 Jan 2026 11:27:36 +0000 Subject: [PATCH 12/14] Fix comments --- .../gpu/grid/gridwise_ab_transfer_wave_tiles.hpp | 2 +- .../gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp index 279c08d4d40..e47bb37a899 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp @@ -132,7 +132,7 @@ struct ABTransferWaveTiles index_t, index_t) { - // Notes: padding is currently not supported + // Notes: padding is currently not supported with transpose static_assert(!((PadMN || PadK) && ABDoTranspose), "padding is currently not supported with transpose"); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index c2c6fd776ba..ec7710d066c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -248,10 +248,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // Limitations of the current implementation: // - no multiAB - // - GemmSpecialization Default - // - pipeline v1 because v3 is buggy (fixed in batched gemm gemm implementation) - // AK1Value == 8 is not really a limitation but a requirement for the method so - // it will stay + // - GemmSpecialization Default with transpose #ifdef __gfx12__ static constexpr bool IsAWaveTransferApplicable = !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 && @@ -272,7 +269,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // We need to investigate if it makes sense to remove cshuffle for smaller types // Currently we use direct store for NRepeat equal to 4 or 8. 
For 16 bit type we use at - // lease buffer store 64 bit for 16 contiguous threads -> 128 bytes in toral (full cache line) + // least buffer store 64 bit for 16 contiguous threads -> 128 bytes in total (full cache line) static constexpr bool UseDirectStore = is_same_v && sizeof(ComputeTypeB) == 2 && sizeof(EDataType) == 2 && NumDTensor == 0 && (NRepeat == 4 || NRepeat == 8) && From e05be3588d71bd0b740f34b1fdfc7773278996aa Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 9 Jan 2026 15:45:58 +0000 Subject: [PATCH 13/14] Fix for gemm_bias_add_reduce flavour --- .../impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp index c64a1d504dc..e8e3b69cb5f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp @@ -188,7 +188,10 @@ struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3 ComputeTypeA, ComputeTypeB, PermuteA, - PermuteB>; + PermuteB, + false, // IsBPreShuffled + false, // ForceThreadTileTransfer + true>; // IsFusedKernel using ReduceTrait = ReduceTrait_ Date: Tue, 13 Jan 2026 08:34:06 +0000 Subject: [PATCH 14/14] Fix grouped gemm tile loop --- ...uped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp index b7c0d89e0f3..5ae9eaf8aca 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp @@ -66,8 +66,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CDEElementwiseOperation cde_element_op) { #if(defined(__gfx11__) || defined(__gfx12__)) - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + using EpilogueType = typename std::conditional::type; + + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[LDS_size]; const auto gemm_desc_ptr = @@ -150,7 +154,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) gemm_desc_ptr[group_id].StrideE, 1); - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; constexpr TailNumber TailNum = TailNumber::Full; if(has_main_k_block_loop)