Skip to content

Commit 3f263d7

Browse files
Add dim=0 to tl.flip calls to fix Triton compile error
1 parent a369a0f commit 3f263d7

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

mlstm_kernels/triton/chunkwise/xl_chunk/fw_kernel_recurrent.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -163,7 +163,7 @@ def mlstm_chunkwise__recurrent_fw_C_kernel(
163163
tl.float32
164164
)
165165

166-
vecA_k_val = tl.flip(tl.cumsum(tl.flip(vecFlogsig_masked), axis=0)) + vecI_k_val
166+
vecA_k_val = tl.flip(tl.cumsum(tl.flip(vecFlogsig_masked, dim=0), axis=0), dim=0) + vecI_k_val
167167

168168
vecFfirst_k_val = tl.load(vecF + idx_b_BNH * str_vecFI_B_NH + k * L + 0).to(
169169
tl.float32

mlstm_kernels/triton/chunkwise/xl_chunk_siging/fw_kernel_recurrent.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -153,7 +153,7 @@ def mlstm_siging_chunkwise__recurrent_fw_C_kernel(
153153
vecIlogsig_k_val = tl.log(tl.sigmoid(vecI_k_val))
154154

155155
vecA_k_val = (
156-
tl.flip(tl.cumsum(tl.flip(vecFlogsig_masked), axis=0)) + vecIlogsig_k_val
156+
tl.flip(tl.cumsum(tl.flip(vecFlogsig_masked, dim=0), axis=0), dim=0) + vecIlogsig_k_val
157157
)
158158

159159
vecFfirst_k_val = tl.load(vecF + idx_b_BNH * str_vecFI_B_NH + k * L + 0).to(

0 commit comments

Comments
 (0)