Skip to content

Commit 5b6f880

Browse files
committed
Bug fix
1 parent 1008714 commit 5b6f880

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

include/cutlass/epilogue/collective/xe_array_epilogue.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
4444
#include "cutlass/epilogue/fusion/xe_visitor_softmax.hpp"
4545
#include "cutlass/detail/layout.hpp"
46+
#include "../tools/util/include/cutlass/util/packed_stride.hpp"
4647

4748
#include "cute/tensor.hpp"
4849

@@ -114,6 +115,7 @@ class CollectiveEpilogue<
114115
using ElementScalar = typename FusionCallbacks::ElementScalar;
115116
static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest;
116117

118+
117119
static_assert(cute::is_any_of_v<typename FusionCallbacks::Operation,
118120
fusion::LinearCombination<ElementAccumulator, ElementCompute, ElementSource, ElementScalar, RoundStyle, false>,
119121
fusion::LinearCombination<ElementAccumulator, ElementCompute, ElementSource, ElementScalar, RoundStyle, true>>,
@@ -244,6 +246,7 @@ class CollectiveEpilogue<
244246
Arguments const& args) {
245247
constexpr int copy_alignment_bits = 128;
246248
constexpr int batch_alignment_bits = 512;
249+
247250
bool implementable = true;
248251
bool fusion_implementable = true;
249252

@@ -493,22 +496,20 @@ template <typename ProblemShape_MNKL>
493496
ElementC const *ptr_C_curr_batch =
494497
reinterpret_cast<ElementC const *>(params.ptr_C[0]) +
495498
cumulative_M * N;
496-
auto c_stride = InternalStrideC{};
497-
cute::get<0>(c_stride) = N;
498499
mC_mnl = make_tensor(
499500
make_gmem_ptr(ptr_C_curr_batch),
500-
make_layout(make_shape(M, N, L), c_stride));
501+
make_layout(make_shape(M, N, L), cutlass::make_cute_packed_stride(
502+
InternalStrideC{}, {M, N, 1})));
501503
}
502504

503505
if constexpr (is_destination_supported) {
504506
ElementD *ptr_D_curr_batch =
505507
reinterpret_cast<ElementD *>(params.ptr_D[0]) +
506508
cumulative_M * N;
507-
auto d_stride = InternalStrideD{};
508-
cute::get<0>(d_stride) = N;
509509
mD_mnl = make_tensor(
510510
make_gmem_ptr(ptr_D_curr_batch),
511-
make_layout(make_shape(M, N, L), d_stride));
511+
make_layout(make_shape(M, N, L), cutlass::make_cute_packed_stride(
512+
InternalStrideD{}, {M, N, 1})));
512513
}
513514
return cute::make_tuple(mC_mnl, mD_mnl);
514515
}

include/cutlass/gemm/collective/xe_array_mma.hpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include "cute/algorithm/functional.hpp"
3838
#include "cute/atom/mma_atom.hpp"
3939
#include "cute/algorithm/gemm.hpp"
40+
#include "../tools/util/include/cutlass/util/packed_stride.hpp"
4041

4142
/////////////////////////////////////////////////////////////////////////////////////////////////
4243

@@ -315,16 +316,13 @@ template <typename ProblemShape_MNKL>
315316
ElementB const *ptr_B_curr_batch =
316317
reinterpret_cast<ElementB const *>(mainloop_params.ptr_B[0]) +
317318
next_group * K * N;
318-
auto a_stride = InternalStrideA{};
319-
cute::get<0>(a_stride) = K;
319+
320320
Tensor mA = make_tensor(
321321
make_gmem_ptr(ptr_A_curr_batch), make_shape(M, K, (int32_t)1),
322-
a_stride);
323-
auto b_stride = InternalStrideB{};
324-
cute::get<0>(b_stride) = K;
322+
cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1}));
325323
Tensor mB = make_tensor(
326324
make_gmem_ptr(ptr_B_curr_batch), make_shape(N, K, (int32_t)1),
327-
b_stride);
325+
cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1}));
328326

329327
return cute::make_tuple(mA, mB);
330328
}

0 commit comments

Comments
 (0)