Skip to content

Commit 21a1bce

Browse files
committed
fix the template args
Signed-off-by: Chen, Xi2 <[email protected]>
1 parent f8a0514 commit 21a1bce

File tree

3 files changed

+4
-5
lines changed

3 files changed

+4
-5
lines changed

applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ template <class DispatchPolicy_,
6262
class TensorV_,
6363
class TiledCopyQ_ = void, // Optional TiledCopy for loading Q
6464
class TiledCopyK_ = void, // Optional TiledCopy for loading K
65-
class TiledCopyV_ = void, // Optional TiledCopy for loading V
66-
class SubgroupLayoutQK_ = void> // Optional SubgroupLayout for QK
65+
class TiledCopyV_ = void> // Optional TiledCopy for loading V
6766
struct FMHAFwdMainloop {
6867
static_assert(cutlass::detail::dependent_false<DispatchPolicy_>, "Could not find a mainloop specialization.");
6968
};
@@ -195,7 +194,7 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
195194
Tensor cK = make_identity_tensor(K_2D.shape()); // (k,d)
196195
Tensor cV = make_identity_tensor(V_2D.shape()); // (v,k)
197196
Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{})); // (q,k)
198-
197+
199198
/* Partition global tensors into workgroup tiles */
200199
Tensor gQ = local_tile(cQ, TileShapeQK{}, append(blk_qv,_), Step<_1,X,_1>{}); // (q,d,D)
201200
Tensor gK = local_tile(cK, TileShapeQK{}, make_coord(_,_,_), Step<X,_1,_1>{}); // (k,d,K,D)

examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,6 @@ int main(int argc, const char **argv) {
147147
#else
148148
constexpr int PipelineStages = 2;
149149
#endif
150-
151150
#ifdef IS_FLOAT_E5M2
152151
using ElementQ = cutlass::float_e5m2_t;
153152
using ElementK = cutlass::float_e5m2_t;

examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,8 @@ struct FMHAConfig {
576576
MainloopDispatchPolicy, Causal,
577577
TiledMMAQK, TiledMMAPV, VTiles,
578578
TensorQ, TensorK, TensorV,
579-
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV>;
579+
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV
580+
>;
580581

581582
// Epilogue
582583
using CollectiveEpilogue = cutlass::fmha::collective::FMHAFwdEpilogue<

0 commit comments

Comments
 (0)