From 63e77fa51a9b7aad7552fc955132d8822d4375d9 Mon Sep 17 00:00:00 2001 From: Nick Riasanovsky Date: Thu, 3 Apr 2025 11:12:19 -0700 Subject: [PATCH] Guard Stride tl.assume fp8_rowwise_gemm (#3924) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/1012 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3924 Guards the tl.assume values generated by the strides behind a non-negative requirement. This is not due to any regression, just a reviewer suggestion. Differential Revision: D72398678 --- .../experimental/gemm/triton_gemm/fp8_gemm.py | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index 954a8235ce..bd19599edd 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -1190,6 +1190,14 @@ def persistent_grid(META): if bias is not None: raise AssertionError("bias is not supported in non-persistent kernel") # pyre-ignore + enable_buffer_ops_assumes = ( + a.stride(0) >= 0 + and a.stride(1) >= 0 + and b.stride(0) >= 0 + and b.stride(1) >= 0 + and c.stride(0) >= 0 + and c.stride(1) >= 0 + ) torch._library.capture_triton(_kernel_matmul_fp8_row_non_persistent)[grid]( a, b, @@ -1215,6 +1223,7 @@ def persistent_grid(META): # GROUP_M=8, # USE_BIAS=bias is not None, AB_DTYPE=False, + ENABLE_BUFFER_OPS_ASSUMES=enable_buffer_ops_assumes, ) elif use_warp_specialization: assert has_warp_specialization @@ -3363,6 +3372,7 @@ def _kernel_matmul_fp8_row_non_persistent( SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr, + ENABLE_BUFFER_OPS_ASSUMES: tl.constexpr, ) -> None: """Matmul kernel of [M, K] @ [N, K] with row-wise scales @@ -3397,15 +3407,16 @@ def _kernel_matmul_fp8_row_non_persistent( EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K. AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core. """ - tl.assume(M >= 0) - tl.assume(N >= 0) - tl.assume(K >= 0) - tl.assume(stride_am >= 0) - tl.assume(stride_ak >= 0) - tl.assume(stride_bn >= 0) - tl.assume(stride_bk >= 0) - tl.assume(stride_cm >= 0) - tl.assume(stride_cn >= 0) + if ENABLE_BUFFER_OPS_ASSUMES: + tl.assume(M >= 0) + tl.assume(N >= 0) + tl.assume(K >= 0) + tl.assume(stride_am >= 0) + tl.assume(stride_ak >= 0) + tl.assume(stride_bn >= 0) + tl.assume(stride_bk >= 0) + tl.assume(stride_cm >= 0) + tl.assume(stride_cn >= 0) # Matrix multiplication. pid = tl.program_id(0) pid_z = tl.program_id(1)