@@ -1215,6 +1215,12 @@ def persistent_grid(META):
1215
1215
# GROUP_M=8,
1216
1216
# USE_BIAS=bias is not None,
1217
1217
AB_DTYPE = False ,
1218
+ NON_NEGATIVE_STRIDE_AM = a .stride (0 ) >= 0 ,
1219
+ NON_NEGATIVE_STRIDE_AK = a .stride (1 ) >= 0 ,
1220
+ NON_NEGATIVE_STRIDE_BN = b .stride (0 ) >= 0 ,
1221
+ NON_NEGATIVE_STRIDE_BK = b .stride (1 ) >= 0 ,
1222
+ NON_NEGATIVE_STRIDE_CM = c .stride (0 ) >= 0 ,
1223
+ NON_NEGATIVE_STRIDE_CN = c .stride (1 ) >= 0 ,
1218
1224
)
1219
1225
elif use_warp_specialization :
1220
1226
assert has_warp_specialization
@@ -3363,6 +3369,12 @@ def _kernel_matmul_fp8_row_non_persistent(
3363
3369
SPLIT_K : tl .constexpr ,
3364
3370
EVEN_K : tl .constexpr ,
3365
3371
AB_DTYPE : tl .constexpr ,
3372
+ NON_NEGATIVE_STRIDE_AM : tl .constexpr ,
3373
+ NON_NEGATIVE_STRIDE_AK : tl .constexpr ,
3374
+ NON_NEGATIVE_STRIDE_BN : tl .constexpr ,
3375
+ NON_NEGATIVE_STRIDE_BK : tl .constexpr ,
3376
+ NON_NEGATIVE_STRIDE_CM : tl .constexpr ,
3377
+ NON_NEGATIVE_STRIDE_CN : tl .constexpr ,
3366
3378
) -> None :
3367
3379
"""Matmul kernel of [M, K] @ [N, K] with row-wise scales
3368
3380
@@ -3400,12 +3412,18 @@ def _kernel_matmul_fp8_row_non_persistent(
3400
3412
tl .assume (M >= 0 )
3401
3413
tl .assume (N >= 0 )
3402
3414
tl .assume (K >= 0 )
3403
- tl .assume (stride_am >= 0 )
3404
- tl .assume (stride_ak >= 0 )
3405
- tl .assume (stride_bn >= 0 )
3406
- tl .assume (stride_bk >= 0 )
3407
- tl .assume (stride_cm >= 0 )
3408
- tl .assume (stride_cn >= 0 )
3415
+ if NON_NEGATIVE_STRIDE_AM :
3416
+ tl .assume (stride_am >= 0 )
3417
+ if NON_NEGATIVE_STRIDE_AK :
3418
+ tl .assume (stride_ak >= 0 )
3419
+ if NON_NEGATIVE_STRIDE_BN :
3420
+ tl .assume (stride_bn >= 0 )
3421
+ if NON_NEGATIVE_STRIDE_BK :
3422
+ tl .assume (stride_bk >= 0 )
3423
+ if NON_NEGATIVE_STRIDE_CM :
3424
+ tl .assume (stride_cm >= 0 )
3425
+ if NON_NEGATIVE_STRIDE_CN :
3426
+ tl .assume (stride_cn >= 0 )
3409
3427
# Matrix multiplication.
3410
3428
pid = tl .program_id (0 )
3411
3429
pid_z = tl .program_id (1 )
0 commit comments