|
43 | 43 | #include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" |
44 | 44 | #include "cutlass/epilogue/fusion/xe_visitor_softmax.hpp" |
45 | 45 | #include "cutlass/detail/layout.hpp" |
| 46 | +#include "../tools/util/include/cutlass/util/packed_stride.hpp" |
46 | 47 |
|
47 | 48 | #include "cute/tensor.hpp" |
48 | 49 |
|
@@ -114,6 +115,7 @@ class CollectiveEpilogue< |
114 | 115 | using ElementScalar = typename FusionCallbacks::ElementScalar; |
115 | 116 | static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest; |
116 | 117 |
|
| 118 | + |
117 | 119 | static_assert(cute::is_any_of_v<typename FusionCallbacks::Operation, |
118 | 120 | fusion::LinearCombination<ElementAccumulator, ElementCompute, ElementSource, ElementScalar, RoundStyle, false>, |
119 | 121 | fusion::LinearCombination<ElementAccumulator, ElementCompute, ElementSource, ElementScalar, RoundStyle, true>>, |
@@ -244,6 +246,7 @@ class CollectiveEpilogue< |
244 | 246 | Arguments const& args) { |
245 | 247 | constexpr int copy_alignment_bits = 128; |
246 | 248 | constexpr int batch_alignment_bits = 512; |
| 249 | + |
247 | 250 | bool implementable = true; |
248 | 251 | bool fusion_implementable = true; |
249 | 252 |
|
@@ -493,22 +496,20 @@ template <typename ProblemShape_MNKL> |
493 | 496 | ElementC const *ptr_C_curr_batch = |
494 | 497 | reinterpret_cast<ElementC const *>(params.ptr_C[0]) + |
495 | 498 | cumulative_M * N; |
496 | | - auto c_stride = InternalStrideC{}; |
497 | | - cute::get<0>(c_stride) = N; |
498 | 499 | mC_mnl = make_tensor( |
499 | 500 | make_gmem_ptr(ptr_C_curr_batch), |
500 | | - make_layout(make_shape(M, N, L), c_stride)); |
| 501 | + make_layout(make_shape(M, N, L), cutlass::make_cute_packed_stride( |
| 502 | + InternalStrideC{}, {M, N, 1}))); |
501 | 503 | } |
502 | 504 |
|
503 | 505 | if constexpr (is_destination_supported) { |
504 | 506 | ElementD *ptr_D_curr_batch = |
505 | 507 | reinterpret_cast<ElementD *>(params.ptr_D[0]) + |
506 | 508 | cumulative_M * N; |
507 | | - auto d_stride = InternalStrideD{}; |
508 | | - cute::get<0>(d_stride) = N; |
509 | 509 | mD_mnl = make_tensor( |
510 | 510 | make_gmem_ptr(ptr_D_curr_batch), |
511 | | - make_layout(make_shape(M, N, L), d_stride)); |
| 511 | + make_layout(make_shape(M, N, L), cutlass::make_cute_packed_stride( |
| 512 | + InternalStrideD{}, {M, N, 1}))); |
512 | 513 | } |
513 | 514 | return cute::make_tuple(mC_mnl, mD_mnl); |
514 | 515 | } |
|
0 commit comments