Skip to content

Commit a217887

Browse files
njriasanfacebook-github-bot
authored andcommitted
Guard Stride tl.assume
Summary: Guards the tl.assume values generated by the strides behind a non-negative requirement. This is not due to any regression, just a reviewer suggestion. Differential Revision: D72398678
1 parent ff203fe commit a217887

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

Diff for: fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -1215,6 +1215,12 @@ def persistent_grid(META):
12151215
# GROUP_M=8,
12161216
# USE_BIAS=bias is not None,
12171217
AB_DTYPE=False,
1218+
NON_NEGATIVE_STRIDE_AM=a.stride(0) >= 0,
1219+
NON_NEGATIVE_STRIDE_AK=a.stride(1) >= 0,
1220+
NON_NEGATIVE_STRIDE_BN=b.stride(0) >= 0,
1221+
NON_NEGATIVE_STRIDE_BK=b.stride(1) >= 0,
1222+
NON_NEGATIVE_STRIDE_CM=c.stride(0) >= 0,
1223+
NON_NEGATIVE_STRIDE_CN=c.stride(1) >= 0,
12181224
)
12191225
elif use_warp_specialization:
12201226
assert has_warp_specialization
@@ -3363,6 +3369,12 @@ def _kernel_matmul_fp8_row_non_persistent(
33633369
SPLIT_K: tl.constexpr,
33643370
EVEN_K: tl.constexpr,
33653371
AB_DTYPE: tl.constexpr,
3372+
NON_NEGATIVE_STRIDE_AM: tl.constexpr,
3373+
NON_NEGATIVE_STRIDE_AK: tl.constexpr,
3374+
NON_NEGATIVE_STRIDE_BN: tl.constexpr,
3375+
NON_NEGATIVE_STRIDE_BK: tl.constexpr,
3376+
NON_NEGATIVE_STRIDE_CM: tl.constexpr,
3377+
NON_NEGATIVE_STRIDE_CN: tl.constexpr,
33663378
) -> None:
33673379
"""Matmul kernel of [M, K] @ [N, K] with row-wise scales
33683380
@@ -3400,12 +3412,18 @@ def _kernel_matmul_fp8_row_non_persistent(
34003412
tl.assume(M >= 0)
34013413
tl.assume(N >= 0)
34023414
tl.assume(K >= 0)
3403-
tl.assume(stride_am >= 0)
3404-
tl.assume(stride_ak >= 0)
3405-
tl.assume(stride_bn >= 0)
3406-
tl.assume(stride_bk >= 0)
3407-
tl.assume(stride_cm >= 0)
3408-
tl.assume(stride_cn >= 0)
3415+
if NON_NEGATIVE_STRIDE_AM:
3416+
tl.assume(stride_am >= 0)
3417+
if NON_NEGATIVE_STRIDE_AK:
3418+
tl.assume(stride_ak >= 0)
3419+
if NON_NEGATIVE_STRIDE_BN:
3420+
tl.assume(stride_bn >= 0)
3421+
if NON_NEGATIVE_STRIDE_BK:
3422+
tl.assume(stride_bk >= 0)
3423+
if NON_NEGATIVE_STRIDE_CM:
3424+
tl.assume(stride_cm >= 0)
3425+
if NON_NEGATIVE_STRIDE_CN:
3426+
tl.assume(stride_cn >= 0)
34093427
# Matrix multiplication.
34103428
pid = tl.program_id(0)
34113429
pid_z = tl.program_id(1)

0 commit comments

Comments
 (0)