From 2797d132de3800d99ebc0d7498252b5b520c4b07 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Thu, 5 Jan 2023 14:28:44 +0100 Subject: [PATCH 01/23] feat: add linear layer backward + activations grad --- src/kernl/implementations/activation_func.py | 23 +++ src/kernl/implementations/linear_layer.py | 196 ++++++++++++++++++- test/test_activations.py | 0 3 files changed, 217 insertions(+), 2 deletions(-) create mode 100644 test/test_activations.py diff --git a/src/kernl/implementations/activation_func.py b/src/kernl/implementations/activation_func.py index 53408a5e..c2b0d36d 100644 --- a/src/kernl/implementations/activation_func.py +++ b/src/kernl/implementations/activation_func.py @@ -23,6 +23,7 @@ sqrt2pi = math.sqrt(2.0 / math.pi) sqrt2 = math.sqrt(2.0) +gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi) @triton.jit @@ -31,19 +32,41 @@ def tanh(x): return tl.libdevice.tanh(x) +@triton.jit +def tanh_grad(x): + """Tanh derivative function""" + return 1 - tl.libdevice.pow(tl.libdevice.tanh(x), 2) + + @triton.jit def relu(x): """Relu activation function""" return tl.maximum(0, x) +@triton.jit +def relu_grad(x): + """Relu derivative function""" + return tl.maximum(0, x) @triton.jit def fast_gelu(x): """Fast approximation of the gelu function. May slightly decrease accuracy.""" return 0.5 * x * (1 + tanh(sqrt2pi * (x + 0.044715 * x * x * x))) +@triton.jit +def fast_gelu_grad(x): + """Derivative of fast approximation of the gelu function.""" + raise NotImplemented() @triton.jit def gelu(x): """Gaussian Error Linear Unit (GELU)""" return x * 0.5 * (1.0 + tl.libdevice.erf(x / sqrt2)) + + +@triton.jit +def gelu_grad(x): + """Derivative of Gaussian Error Linear Unit (GELU)""" + cdf = 0.5 * (1.0 + tl.libdevice.erf(x * sqrt2)) + pdf = tl.exp(-0.5 * x * x) * gaussian_pdf_normalization + return cdf + x * pdf diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index a30eebc6..5088a804 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -138,6 +138,7 @@ def kernel_fma( This kernel will consolidate over K """ + assert ACTIVATION in ["tanh", "gelu", "fast_gelu", "relu"], f"{ACTIVATION} is not supported" pid = tl.program_id(axis=0) grid_m = (M + BLOCK_M - 1) // BLOCK_M @@ -204,7 +205,7 @@ def kernel_fma( tl.store(C, acc, mask=mask) -class LinearLayer(torch.autograd.Function): +class LinearLayerFwd(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward( @@ -281,4 +282,195 @@ def linear_layer( activation="", act_inputs: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return LinearLayer.apply(x, weight, bias, activation, act_inputs) + return LinearLayerFwd.apply(x, weight, bias, activation, act_inputs) + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), 
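+        # Each candidate config fixes the output tile shape (BLOCK_M x BLOCK_N), the K-step
+        # loaded per loop iteration (BLOCK_K), the software-pipelining depth (num_stages) and
+        # the number of warps per program (num_warps); SPLIT_K stays at 1 in these configs,
+        # so a single program performs the whole K reduction for its tile.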
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2), + ] + + get_configs_io_bound(), + key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], + prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10}, +) +@triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + } +) +@triton.jit +def kernel_bwd( + C, # Pointers to matrices + ACT_INPUT, + A, + B, + # Matrix dimensions + M, + N, + K, + CACHE_KEY_M, + CACHE_KEY_N, + CACHE_KEY_K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. stride_am is how much to increase a_ptr + # by to get the element one row down (A has M rows) + stride_om, + stride_on, + stride_im, + stride_ik, + stride_wn, + stride_wk, + # Meta-parameters + BLOCK_M: tl.constexpr, + GROUP_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + # split k not used, not performant with activation, kept because early_config_prune is expecting it + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + ACTIVATION: tl.constexpr, +): + pid = tl.program_id(axis=0) + + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + # now compute the block that each program will go through + # rm (resp. rn) denotes a range of indices + # for rows (resp. 
col) of C + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + # trick to avoid masking on M and N axis + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + + A = A + (ram[:, None] * stride_im + rk[None, :] * stride_ik) + B = B + (rk[:, None] * stride_wk + rbn[None, :] * stride_wn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.0) + b = tl.load(B, mask=rk[:, None] < k, other=0.0) + acc += tl.dot(a, b) + + A += BLOCK_K * stride_ik + B += BLOCK_K * stride_wk + + # optional: fused activation (while the data is in shared memory) + if ACTIVATION != "id": + act_in_ptrs = ACT_INPUT + ram[:, None] * stride_om + rbn[None, :] * stride_on + act_input = tl.load(act_in_ptrs).to(acc.dtype) + if ACTIVATION == "tanh": + acc *= activation_func.tanh_grad(act_input) + if ACTIVATION == "gelu": + acc *= activation_func.gelu_grad(act_input) + if ACTIVATION == "fast_gelu": + acc *= activation_func.fast_gelu_grad(act_input) + if ACTIVATION == "relu": + acc *= activation_func.relu_grad(act_input) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # write back result + C = C + rm[:, None] * stride_om + rn[None, :] * stride_on + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.store(C, acc, mask=mask) + + +def LinearlayerBwd( + grad_output: torch.Tensor, + weight: torch.Tensor, + activation: str = "id", + act_input: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Compute e = activation(grad_output @ weight + bias). + This wrapper kicks the `kernel_fwd` Triton kernel + :param grad_output: input tensor + :param weight: weight matrix + :param activation: Activation name. Needs to be a Triton kernel. + :param act_input: an optional tensor to save the activation inputs (for backward) + :return: result tensor + """ + assert activation in ["id", "gelu", "gelu_approx", "squared_relu"] + + batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1] + batch_dim = batch_shape.numel() + grad_output_reshaped = grad_output.reshape(batch_dim, n) + + if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1: + grad_output_reshaped = grad_output_reshaped.contiguous() + if weight.stride(0) > 1 and weight.stride(1) > 1: + weight = weight.contiguous() + + assert ( + grad_output.dtype == weight.dtype + ), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}" + assert ( + grad_output_reshaped.shape[1] == weight.shape[0] + ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}" + if activation != "id": + assert act_input is not None, f"act_input is required for activation {activation}" + + # M, N, K in bwd are different from M, N, K in fwd + M, K = grad_output_reshaped.shape + K, N = weight.shape + + grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype) + + # 1D launch kernel where each block gets its own program. 
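+    # One program per (BLOCK_M x BLOCK_N) tile of grad_input, i.e.
+    # ceil(M / BLOCK_M) * ceil(N / BLOCK_N) programs in total; the block
+    # sizes are chosen by the autotuner, hence the META-based lambda below.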
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa + + kernel_bwd[grid]( + grad_input, + act_input, + grad_output_reshaped, + weight, # data ptrs + M, # shapes + N, + K, + M // 32, # key for triton cache (limit number of compilations) + N // 32, + K // 32, + stride_cm=grad_input.stride(0), # strides + # stride_cn=grad_input.stride(1), + stride_am=grad_output_reshaped.stride(0), + stride_ak=grad_output_reshaped.stride(1), + stride_bk=weight.stride(0), + stride_bn=weight.stride(1), + ACTIVATION=activation, # optional fused activation + GROUP_M=8, # speed optimization: group the programs + ) + + return grad_input.reshape(*batch_shape, grad_input.shape[-1]) diff --git a/test/test_activations.py b/test/test_activations.py new file mode 100644 index 00000000..e69de29b From d332ade3376d976d832f683a5bcbb94c2ace73c9 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Thu, 5 Jan 2023 16:00:59 +0100 Subject: [PATCH 02/23] feat: add linear backwar wrapper +renaming + code formatting --- src/kernl/implementations/activation_func.py | 4 + src/kernl/implementations/linear_layer.py | 145 ++++++++++--------- src/kernl/optimizer/linear.py | 4 +- 3 files changed, 86 insertions(+), 67 deletions(-) diff --git a/src/kernl/implementations/activation_func.py b/src/kernl/implementations/activation_func.py index c2b0d36d..e0f51bee 100644 --- a/src/kernl/implementations/activation_func.py +++ b/src/kernl/implementations/activation_func.py @@ -43,21 +43,25 @@ def relu(x): """Relu activation function""" return tl.maximum(0, x) + @triton.jit def relu_grad(x): """Relu derivative function""" return tl.maximum(0, x) + @triton.jit def fast_gelu(x): """Fast approximation of the gelu function. May slightly decrease accuracy.""" return 0.5 * x * (1 + tanh(sqrt2pi * (x + 0.044715 * x * x * x))) + @triton.jit def fast_gelu_grad(x): """Derivative of fast approximation of the gelu function.""" raise NotImplemented() + @triton.jit def gelu(x): """Gaussian Error Linear Unit (GELU)""" diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index 5088a804..9459f3c0 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -22,7 +22,7 @@ import triton import triton.language as tl from torch.autograd.function import FunctionCtx -from torch.cuda.amp import custom_fwd +from torch.cuda.amp import custom_bwd, custom_fwd from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time from kernl.implementations import activation_func @@ -275,7 +275,7 @@ def forward( return outputs -def linear_layer( +def linear_layer_fwd( x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], @@ -284,6 +284,7 @@ def linear_layer( ) -> torch.Tensor: return LinearLayerFwd.apply(x, weight, bias, activation, act_inputs) + @triton.autotune( configs=[ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), @@ -408,69 +409,83 @@ def kernel_bwd( tl.store(C, acc, mask=mask) -def LinearlayerBwd( +class LinearLayerBwd(torch.autograd.Function): + @staticmethod + @custom_bwd + def backward( + ctx: FunctionCtx, + grad_outputs: torch.Tensor, + weight: torch.Tensor, + activation: str = "id", + act_inputs: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Compute e = activation(grad_output @ weight + bias). 
+ This wrapper kicks the `kernel_fwd` Triton kernel + :param ctx: context for autograd + :param grad_outputs: input tensor + :param weight: weight matrix + :param activation: Activation name. Needs to be a Triton kernel. + :param act_inputs: an optional tensor to save the activation inputs (for backward) + :return: result tensor + """ + assert activation in ["id", "gelu", "gelu_approx", "squared_relu"] + + batch_shape, n = grad_outputs.shape[:-1], grad_outputs.shape[-1] + batch_dim = batch_shape.numel() + grad_output_reshaped = grad_outputs.reshape(batch_dim, n) + + if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1: + grad_output_reshaped = grad_output_reshaped.contiguous() + if weight.stride(0) > 1 and weight.stride(1) > 1: + weight = weight.contiguous() + + assert ( + grad_outputs.dtype == weight.dtype + ), f"grad_output and weight must have the same dtype, got {grad_outputs.dtype} and {weight.dtype}" + assert ( + grad_output_reshaped.shape[1] == weight.shape[0] + ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}" + if activation != "id": + assert act_inputs is not None, f"act_input is required for activation {activation}" + + # M, N, K in bwd are different from M, N, K in fwd + M, K = grad_output_reshaped.shape + K, N = weight.shape + + grad_input = torch.empty((M, N), device=grad_outputs.device, dtype=grad_outputs.dtype) + + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa + + kernel_bwd[grid]( + grad_input, + act_inputs, + grad_output_reshaped, + weight, # data ptrs + M, # shapes + N, + K, + M // 32, # key for triton cache (limit number of compilations) + N // 32, + K // 32, + stride_cm=grad_input.stride(0), # strides + # stride_cn=grad_input.stride(1), + stride_am=grad_output_reshaped.stride(0), + stride_ak=grad_output_reshaped.stride(1), + stride_bk=weight.stride(0), + stride_bn=weight.stride(1), + ACTIVATION=activation, # optional fused activation + GROUP_M=8, # speed optimization: group the programs + ) + + return grad_input.reshape(*batch_shape, grad_input.shape[-1]) + + +def linear_layer_bwd( grad_output: torch.Tensor, weight: torch.Tensor, - activation: str = "id", - act_input: Optional[torch.Tensor] = None, + activation="", + act_inputs: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """ - Compute e = activation(grad_output @ weight + bias). - This wrapper kicks the `kernel_fwd` Triton kernel - :param grad_output: input tensor - :param weight: weight matrix - :param activation: Activation name. Needs to be a Triton kernel. 
- :param act_input: an optional tensor to save the activation inputs (for backward) - :return: result tensor - """ - assert activation in ["id", "gelu", "gelu_approx", "squared_relu"] - - batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1] - batch_dim = batch_shape.numel() - grad_output_reshaped = grad_output.reshape(batch_dim, n) - - if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1: - grad_output_reshaped = grad_output_reshaped.contiguous() - if weight.stride(0) > 1 and weight.stride(1) > 1: - weight = weight.contiguous() - - assert ( - grad_output.dtype == weight.dtype - ), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}" - assert ( - grad_output_reshaped.shape[1] == weight.shape[0] - ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}" - if activation != "id": - assert act_input is not None, f"act_input is required for activation {activation}" - - # M, N, K in bwd are different from M, N, K in fwd - M, K = grad_output_reshaped.shape - K, N = weight.shape - - grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype) - - # 1D launch kernel where each block gets its own program. - grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa - - kernel_bwd[grid]( - grad_input, - act_input, - grad_output_reshaped, - weight, # data ptrs - M, # shapes - N, - K, - M // 32, # key for triton cache (limit number of compilations) - N // 32, - K // 32, - stride_cm=grad_input.stride(0), # strides - # stride_cn=grad_input.stride(1), - stride_am=grad_output_reshaped.stride(0), - stride_ak=grad_output_reshaped.stride(1), - stride_bk=weight.stride(0), - stride_bn=weight.stride(1), - ACTIVATION=activation, # optional fused activation - GROUP_M=8, # speed optimization: group the programs - ) - - return grad_input.reshape(*batch_shape, grad_input.shape[-1]) + return LinearLayerBwd.apply(grad_output, weight, activation, act_inputs) diff --git a/src/kernl/optimizer/linear.py b/src/kernl/optimizer/linear.py index e9e6734a..f9c6406c 100644 --- a/src/kernl/optimizer/linear.py +++ b/src/kernl/optimizer/linear.py @@ -17,7 +17,7 @@ import torch -from kernl.implementations.linear_layer import linear_layer +from kernl.implementations.linear_layer import linear_layer_fwd from kernl.utils.extended_matcher import replace_pattern @@ -35,7 +35,7 @@ def linear_wrapper_functional(v: torch.Tensor, weight: torch.Tensor, bias: torch if bias is not None and bias.dtype == torch.float32: bias.data = bias.data.half() - return linear_layer(v, weight, bias, activation=activation) + return linear_layer_fwd(v, weight, bias, activation=activation) torch.fx.wrap("linear_wrapper_functional") From 5a6a4877c60fbfd8f3e5671e7facef9e1028d23d Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Thu, 5 Jan 2023 16:07:48 +0100 Subject: [PATCH 03/23] feat: fix code formatting --- src/kernl/implementations/activation_func.py | 2 +- src/kernl/implementations/linear_layer.py | 2 +- test/test_linear_layer.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/kernl/implementations/activation_func.py b/src/kernl/implementations/activation_func.py index e0f51bee..bb31a241 100644 --- a/src/kernl/implementations/activation_func.py +++ b/src/kernl/implementations/activation_func.py @@ -59,7 +59,7 @@ def fast_gelu(x): @triton.jit def fast_gelu_grad(x): """Derivative of fast approximation of the gelu function.""" - raise NotImplemented() + raise NotImplementedError() 
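+    # Closed form, for reference: with u = sqrt(2/pi) * (x + 0.044715 * x**3), the
+    # derivative of 0.5 * x * (1 + tanh(u)) is
+    #   0.5 * (1 + tanh(u)) + 0.5 * x * (1 - tanh(u)**2) * sqrt(2/pi) * (1 + 3 * 0.044715 * x**2)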
@triton.jit diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index 9459f3c0..37fd524d 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -441,7 +441,7 @@ def backward( weight = weight.contiguous() assert ( - grad_outputs.dtype == weight.dtype + grad_outputs.dtype == weight.dtype ), f"grad_output and weight must have the same dtype, got {grad_outputs.dtype} and {weight.dtype}" assert ( grad_output_reshaped.shape[1] == weight.shape[0] diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py index 5e5a5539..d284d8ed 100644 --- a/test/test_linear_layer.py +++ b/test/test_linear_layer.py @@ -21,7 +21,7 @@ from conftest import assert_all_close, set_seed from kernl.implementations.cuda_graph import cuda_graphs_wrapper -from kernl.implementations.linear_layer import linear_layer +from kernl.implementations.linear_layer import linear_layer_fwd def get_pytorch_activation(activation: str) -> Callable: @@ -41,7 +41,7 @@ def get_pytorch_activation(activation: str) -> Callable: "pytorch": lambda weight, bias, activation: lambda x: get_pytorch_activation(activation)( torch.nn.functional.linear(x, weight, bias) ), - "triton": lambda weight, bias, activation: lambda x: linear_layer(x, weight, bias, activation), + "triton": lambda weight, bias, activation: lambda x: linear_layer_fwd(x, weight, bias, activation), } From 5a4eba6d152ef473a6d3c96a1e3b6c0331be5228 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Thu, 5 Jan 2023 17:37:54 +0100 Subject: [PATCH 04/23] feat: add fast gelu grad --- src/kernl/implementations/activation_func.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/kernl/implementations/activation_func.py b/src/kernl/implementations/activation_func.py index bb31a241..20e2415b 100644 --- a/src/kernl/implementations/activation_func.py +++ b/src/kernl/implementations/activation_func.py @@ -59,7 +59,10 @@ def fast_gelu(x): @triton.jit def fast_gelu_grad(x): """Derivative of fast approximation of the gelu function.""" - raise NotImplementedError() + # CREDITS: Fast implementation proposed in + # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30 + tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) @triton.jit From 1587cffccec699ef6872ce0e21461bfdc2c67ec9 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Tue, 10 Jan 2023 19:06:16 +0100 Subject: [PATCH 05/23] feat: fix linear layer (forward and backward) --- src/kernl/implementations/linear_layer.py | 169 +++++++++++----------- test/test_activations.py | 0 test/test_linear_layer.py | 6 +- 3 files changed, 87 insertions(+), 88 deletions(-) delete mode 100644 test/test_activations.py diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index 37fd524d..66f58dfb 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -138,7 +138,7 @@ def kernel_fma( This kernel will consolidate over K """ - assert ACTIVATION in ["tanh", "gelu", "fast_gelu", "relu"], f"{ACTIVATION} is not supported" + assert ACTIVATION in ["", "tanh", "gelu", "fast_gelu", "relu"], f"{ACTIVATION} is not supported" pid = tl.program_id(axis=0) grid_m = (M + BLOCK_M - 1) // BLOCK_M @@ -205,86 +205,6 @@ def kernel_fma( tl.store(C, acc, mask=mask) -class LinearLayerFwd(torch.autograd.Function): - @staticmethod - 
@custom_fwd(cast_inputs=torch.float16) - def forward( - ctx: FunctionCtx, - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor], - activation: str, - act_inputs: Optional[torch.Tensor], - ) -> torch.Tensor: - """ - Compute e = activation(x @ weight + bias). - This wrapper kicks the `kernel_fma` Triton kernel - :param ctx: context for autograd - :param x: input tensor - :param weight: weight matrix - :param bias: an optional bias tensor - :param activation: Activation name. Needs to be a Triton kernel. - :param act_inputs: an optional tensor to save the activation inputs (for backward) - :return: result tensor - """ - x_ = x if x.ndim == 2 else x.flatten(0, 1) - - assert x.dtype == weight.dtype, f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}" - if bias is not None: - assert x.dtype == bias.dtype, f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}" - assert x_.shape[1] == weight.shape[1], f"Incompatible dimensions: {x_.shape} - {weight.shape}" - - assert bias is None or bias.is_contiguous() - assert bias is None or bias.shape[0] == weight.shape[0], "Incompatible dimensions in between weight and bias" - assert weight.is_contiguous() - - M, K = x_.shape - N, K = weight.shape - - outputs = torch.empty((M, N), device=x.device, dtype=x.dtype) - - # 1D launch kernel where each block gets its own program. - grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa - - kernel_fma[grid]( - outputs, - act_inputs, - x_, - weight, # data ptrs - bias if bias is not None else x, # auto skip bias if not present - M, # shapes - N, - K, - M // 32, # key for triton cache (limit number of compilations) - N // 32, - K // 32, - stride_om=outputs.stride(0), # strides - stride_on=outputs.stride(1), - stride_im=x_.stride(0), - stride_ik=x_.stride(1), - stride_wn=weight.stride(0), - stride_wk=weight.stride(1), - BIAS=bias is not None, # optional fused bias - SAVE_ACT_INPUTS=act_inputs is not None, # optional save activation inputs - ACTIVATION=activation if not None else x, # optional fused activation - GROUP_M=8, # speed optimization: group the programs - ) - - outputs = outputs if x.ndim == 2 else outputs.reshape(x.shape[0], -1, N) - ctx.save_for_backward(weight, bias, x) - return outputs - - -def linear_layer_fwd( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor], - activation="", - act_inputs: Optional[torch.Tensor] = None, -) -> torch.Tensor: - return LinearLayerFwd.apply(x, weight, bias, activation, act_inputs) - - @triton.autotune( configs=[ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), @@ -348,6 +268,7 @@ def kernel_bwd( EVEN_K: tl.constexpr, ACTIVATION: tl.constexpr, ): + assert ACTIVATION in ["", "tanh", "gelu", "fast_gelu", "relu"], f"{ACTIVATION} is not supported" pid = tl.program_id(axis=0) grid_m = (M + BLOCK_M - 1) // BLOCK_M @@ -387,7 +308,7 @@ def kernel_bwd( B += BLOCK_K * stride_wk # optional: fused activation (while the data is in shared memory) - if ACTIVATION != "id": + if ACTIVATION not in ["", "id"]: act_in_ptrs = ACT_INPUT + ram[:, None] * stride_om + rbn[None, :] * stride_on act_input = tl.load(act_in_ptrs).to(acc.dtype) if ACTIVATION == "tanh": @@ -409,7 +330,75 @@ def kernel_bwd( tl.store(C, acc, mask=mask) -class LinearLayerBwd(torch.autograd.Function): +class LinearLayer(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: FunctionCtx, + x: 
torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + activation: str, + act_inputs: Optional[torch.Tensor], + ) -> torch.Tensor: + """ + Compute e = activation(x @ weight + bias). + This wrapper kicks the `kernel_fma` Triton kernel + :param ctx: context for autograd + :param x: input tensor + :param weight: weight matrix + :param bias: an optional bias tensor + :param activation: Activation name. Needs to be a Triton kernel. + :param act_inputs: an optional tensor to save the activation inputs (for backward) + :return: result tensor + """ + x_ = x if x.ndim == 2 else x.flatten(0, 1) + + assert x.dtype == weight.dtype, f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}" + if bias is not None: + assert x.dtype == bias.dtype, f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}" + assert x_.shape[1] == weight.shape[1], f"Incompatible dimensions: {x_.shape} - {weight.shape}" + + assert bias is None or bias.is_contiguous() + assert bias is None or bias.shape[0] == weight.shape[0], "Incompatible dimensions in between weight and bias" + assert weight.is_contiguous() + + M, K = x_.shape + N, K = weight.shape + + outputs = torch.empty((M, N), device=x.device, dtype=x.dtype) + + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa + + kernel_fma[grid]( + outputs, + act_inputs, + x_, + weight, # data ptrs + bias if bias is not None else x, # auto skip bias if not present + M, # shapes + N, + K, + M // 32, # key for triton cache (limit number of compilations) + N // 32, + K // 32, + stride_om=outputs.stride(0), # strides + stride_on=outputs.stride(1), + stride_im=x_.stride(0), + stride_ik=x_.stride(1), + stride_wn=weight.stride(0), + stride_wk=weight.stride(1), + BIAS=bias is not None, # optional fused bias + SAVE_ACT_INPUTS=act_inputs is not None, # optional save activation inputs + ACTIVATION=activation if not None else x, # optional fused activation + GROUP_M=8, # speed optimization: group the programs + ) + + outputs = outputs if x.ndim == 2 else outputs.reshape(x.shape[0], -1, N) + ctx.save_for_backward(weight, bias, x) + return outputs + @staticmethod @custom_bwd def backward( @@ -429,7 +418,7 @@ def backward( :param act_inputs: an optional tensor to save the activation inputs (for backward) :return: result tensor """ - assert activation in ["id", "gelu", "gelu_approx", "squared_relu"] + assert activation in ["", "id", "gelu", "gelu_approx", "squared_relu"] batch_shape, n = grad_outputs.shape[:-1], grad_outputs.shape[-1] batch_dim = batch_shape.numel() @@ -482,10 +471,20 @@ def backward( return grad_input.reshape(*batch_shape, grad_input.shape[-1]) +def linear_layer_fwd( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + activation="", + act_inputs: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return LinearLayer.forward(x, weight, bias, activation, act_inputs) + + def linear_layer_bwd( grad_output: torch.Tensor, weight: torch.Tensor, activation="", act_inputs: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return LinearLayerBwd.apply(grad_output, weight, activation, act_inputs) + return LinearLayer.backward(grad_output, weight, activation, act_inputs) diff --git a/test/test_activations.py b/test/test_activations.py deleted file mode 100644 index e69de29b..00000000 diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py index d284d8ed..a8ebab3a 100644 --- 
a/test/test_linear_layer.py +++ b/test/test_linear_layer.py @@ -37,7 +37,7 @@ def get_pytorch_activation(activation: str) -> Callable: raise ValueError(f"Unknown activation: {activation}") -implementations = { +forward_implementations = { "pytorch": lambda weight, bias, activation: lambda x: get_pytorch_activation(activation)( torch.nn.functional.linear(x, weight, bias) ), @@ -57,7 +57,7 @@ def get_pytorch_activation(activation: str) -> Callable: @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"]) @pytest.mark.parametrize("cuda_graphs", [False, True], ids=["no_cuda_graphs", "cuda_graphs"]) @pytest.mark.parametrize("implementation", ["triton", "pytorch"]) -def test_benchmark( +def test_benchmark_linear_forward( benchmark, implementation: str, cuda_graphs: bool, @@ -88,7 +88,7 @@ def test_benchmark( layer_bias = layer_bias.to(dtype=dtype) x = x.to(dtype=dtype) - fn = implementations[implementation](layer_weight, layer_bias, activation) + fn = forward_implementations[implementation](layer_weight, layer_bias, activation) if cuda_graphs: run = cuda_graphs_wrapper(model=fn, inputs=[x]) # CUDA graphs wraps output in a tuple From 1cbfadf131c766fd5aea528aba0c3466999dd084 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Tue, 10 Jan 2023 23:42:00 +0100 Subject: [PATCH 06/23] feat: fix backward kernel call --- src/kernl/implementations/linear_layer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index 66f58dfb..a716ba91 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -458,12 +458,12 @@ def backward( M // 32, # key for triton cache (limit number of compilations) N // 32, K // 32, - stride_cm=grad_input.stride(0), # strides - # stride_cn=grad_input.stride(1), - stride_am=grad_output_reshaped.stride(0), - stride_ak=grad_output_reshaped.stride(1), - stride_bk=weight.stride(0), - stride_bn=weight.stride(1), + stride_om=grad_input.stride(0), # strides + stride_on=grad_input.stride(1), + stride_im=grad_output_reshaped.stride(0), + stride_ik=grad_output_reshaped.stride(1), + stride_wn=weight.stride(0), + stride_wk=weight.stride(1), ACTIVATION=activation, # optional fused activation GROUP_M=8, # speed optimization: group the programs ) From c4fb7716186e3e18ab63fd0672901a7d67b19380 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Wed, 11 Jan 2023 09:59:20 +0100 Subject: [PATCH 07/23] feat: add backward benchmark --- src/kernl/implementations/linear_layer.py | 21 ++++++- test/test_linear_layer.py | 76 ++++++++++++++++++++++- 2 files changed, 91 insertions(+), 6 deletions(-) diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index a716ba91..64fcc33d 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -330,7 +330,7 @@ def kernel_bwd( tl.store(C, acc, mask=mask) -class LinearLayer(torch.autograd.Function): +class LinearLayerFwd(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward( @@ -399,6 +399,20 @@ def forward( ctx.save_for_backward(weight, bias, x) return outputs + +class LinearLayerBwd(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: FunctionCtx, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + activation: str, + act_inputs: Optional[torch.Tensor], + ) -> torch.Tensor: + 
LinearLayerFwd.forward(ctx, x, weight, bias, activation, act_inputs) + @staticmethod @custom_bwd def backward( @@ -478,13 +492,14 @@ def linear_layer_fwd( activation="", act_inputs: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return LinearLayer.forward(x, weight, bias, activation, act_inputs) + return LinearLayerFwd.apply(x, weight, bias, activation, act_inputs) def linear_layer_bwd( grad_output: torch.Tensor, weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, activation="", act_inputs: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return LinearLayer.backward(grad_output, weight, activation, act_inputs) + return LinearLayerBwd.apply(grad_output, weight, bias, activation, act_inputs) diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py index a8ebab3a..95803b9b 100644 --- a/test/test_linear_layer.py +++ b/test/test_linear_layer.py @@ -21,7 +21,7 @@ from conftest import assert_all_close, set_seed from kernl.implementations.cuda_graph import cuda_graphs_wrapper -from kernl.implementations.linear_layer import linear_layer_fwd +from kernl.implementations.linear_layer import linear_layer_bwd, linear_layer_fwd def get_pytorch_activation(activation: str) -> Callable: @@ -41,13 +41,16 @@ def get_pytorch_activation(activation: str) -> Callable: "pytorch": lambda weight, bias, activation: lambda x: get_pytorch_activation(activation)( torch.nn.functional.linear(x, weight, bias) ), - "triton": lambda weight, bias, activation: lambda x: linear_layer_fwd(x, weight, bias, activation), + "triton": lambda weight, bias, activation, act_inputs: lambda x: linear_layer_fwd( + x, weight, bias, activation, act_inputs + ), } @set_seed() @pytest.mark.parametrize("contiguous", [True, False], ids=["contiguous", "non-contiguous"]) @pytest.mark.parametrize("bias", [True, False], ids=["with_bias", "no_bias"]) +@pytest.mark.parametrize("act_inputs", [True, False], ids=["with_act_inputs", "no_act_inputs"]) @pytest.mark.parametrize("activation", ["", "tanh", "gelu", "relu"], ids=["no_activation", "tanh", "gelu", "relu"]) @pytest.mark.parametrize( "shape", @@ -66,6 +69,7 @@ def test_benchmark_linear_forward( bias: bool, activation: str, contiguous: bool, + act_inputs: bool, ): batch, M, N, K = shape @@ -79,6 +83,7 @@ def test_benchmark_linear_forward( factory_kwargs = {"device": "cuda", "dtype": torch.float32, "requires_grad": False} layer_weight = torch.randn((N, K), **factory_kwargs) layer_bias = torch.randn((K,), **factory_kwargs) if bias else None + act_inputs = torch.zeros((M, N), **factory_kwargs) if act_inputs else None pytorch_layer_activation = get_pytorch_activation(activation) expected = pytorch_layer_activation(torch.nn.functional.linear(x, layer_weight, layer_bias)) @@ -86,9 +91,74 @@ def test_benchmark_linear_forward( layer_weight = layer_weight.to(dtype=dtype) if layer_bias is not None: layer_bias = layer_bias.to(dtype=dtype) + if act_inputs is not None: + act_inputs = act_inputs.to(dtype=dtype) + x = x.to(dtype=dtype) + + fn = forward_implementations[implementation](layer_weight, layer_bias, activation, act_inputs) + if cuda_graphs: + run = cuda_graphs_wrapper(model=fn, inputs=[x]) + # CUDA graphs wraps output in a tuple + fn = lambda tensor: run([tensor])[0] # noqa: E731 + + value = benchmark(fn, x) + + assert_all_close(expected, value.float(), rtol=1e-1, atol=1e-1) + + +backward_implementations = { + "pytorch": lambda weight, bias, activation: lambda x: get_pytorch_activation(activation)( + torch.nn.functional.linear(x, weight, bias).backward() + ), + "triton": lambda 
weight, bias, activation, act_inputs: lambda x: linear_layer_bwd( + x, weight, bias, activation, act_inputs + ), +} + + +@set_seed() +@pytest.mark.parametrize("contiguous", [True, False], ids=["contiguous", "non-contiguous"]) +@pytest.mark.parametrize("bias", [True, False], ids=["with_bias", "no_bias"]) +@pytest.mark.parametrize("activation", ["", "tanh", "gelu", "relu"], ids=["no_activation", "tanh", "gelu", "relu"]) +@pytest.mark.parametrize( + "shape", + [(1, 8, 8, 8)] + [(bs, M, 768, 768) for bs in [1, 16] for M in [8, 16, 128, 256, 512]], + ids=lambda s: "x".join(map(str, s)), +) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"]) +@pytest.mark.parametrize("cuda_graphs", [False], ids=["no_cuda_graphs"]) +@pytest.mark.parametrize("implementation", ["triton", "pytorch"]) +def test_benchmark_linear_backward( + benchmark, + implementation: str, + cuda_graphs: bool, + shape: Tuple[int, int, int, int], + dtype: torch.dtype, + bias: bool, + activation: str, + contiguous: bool, +): + batch, M, N, K = shape + + # order of dimensions is wrong so we force contiguous call + x = torch.randn((batch, K, M), device="cuda", dtype=torch.float32, requires_grad=True) + x = x.mT + if contiguous: + x = x.contiguous() + else: + assert not x.is_contiguous() + factory_kwargs = {"device": "cuda", "dtype": torch.float32, "requires_grad": True} + layer_weight = torch.randn((N, K), **factory_kwargs) + layer_bias = torch.randn((K,), **factory_kwargs) if bias else None + linear_output = torch.nn.functional.linear(x, layer_weight, layer_bias) + expected = linear_output.backward(linear_output, retain_graph=True) + + # tensors casting + layer_weight = layer_weight.to(dtype=dtype) + act_inputs = torch.zeros((M, N), **factory_kwargs) x = x.to(dtype=dtype) - fn = forward_implementations[implementation](layer_weight, layer_bias, activation) + fn = backward_implementations[implementation](layer_weight, layer_bias, activation, act_inputs) if cuda_graphs: run = cuda_graphs_wrapper(model=fn, inputs=[x]) # CUDA graphs wraps output in a tuple From ed5dc3350a0997340b21439d57a4496cff75e9e6 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Wed, 11 Jan 2023 14:50:13 +0100 Subject: [PATCH 08/23] feat: define linear_layer class with forward and backward functions --- src/kernl/implementations/linear_layer.py | 30 +++-------------------- src/kernl/optimizer/linear.py | 4 +-- test/test_linear_layer.py | 12 ++++----- 3 files changed, 11 insertions(+), 35 deletions(-) diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index 64fcc33d..1840ab3a 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -330,7 +330,7 @@ def kernel_bwd( tl.store(C, acc, mask=mask) -class LinearLayerFwd(torch.autograd.Function): +class LinearLayer(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward( @@ -399,20 +399,6 @@ def forward( ctx.save_for_backward(weight, bias, x) return outputs - -class LinearLayerBwd(torch.autograd.Function): - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward( - ctx: FunctionCtx, - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor], - activation: str, - act_inputs: Optional[torch.Tensor], - ) -> torch.Tensor: - LinearLayerFwd.forward(ctx, x, weight, bias, activation, act_inputs) - @staticmethod @custom_bwd def backward( @@ -485,21 +471,11 @@ def backward( return grad_input.reshape(*batch_shape, grad_input.shape[-1]) -def 
linear_layer_fwd( +def linear_layer( x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], activation="", act_inputs: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return LinearLayerFwd.apply(x, weight, bias, activation, act_inputs) - - -def linear_layer_bwd( - grad_output: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - activation="", - act_inputs: Optional[torch.Tensor] = None, -) -> torch.Tensor: - return LinearLayerBwd.apply(grad_output, weight, bias, activation, act_inputs) + return LinearLayer.apply(x, weight, bias, activation, act_inputs) diff --git a/src/kernl/optimizer/linear.py b/src/kernl/optimizer/linear.py index f9c6406c..e9e6734a 100644 --- a/src/kernl/optimizer/linear.py +++ b/src/kernl/optimizer/linear.py @@ -17,7 +17,7 @@ import torch -from kernl.implementations.linear_layer import linear_layer_fwd +from kernl.implementations.linear_layer import linear_layer from kernl.utils.extended_matcher import replace_pattern @@ -35,7 +35,7 @@ def linear_wrapper_functional(v: torch.Tensor, weight: torch.Tensor, bias: torch if bias is not None and bias.dtype == torch.float32: bias.data = bias.data.half() - return linear_layer_fwd(v, weight, bias, activation=activation) + return linear_layer(v, weight, bias, activation=activation) torch.fx.wrap("linear_wrapper_functional") diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py index 95803b9b..c16c65ad 100644 --- a/test/test_linear_layer.py +++ b/test/test_linear_layer.py @@ -21,7 +21,7 @@ from conftest import assert_all_close, set_seed from kernl.implementations.cuda_graph import cuda_graphs_wrapper -from kernl.implementations.linear_layer import linear_layer_bwd, linear_layer_fwd +from kernl.implementations.linear_layer import linear_layer def get_pytorch_activation(activation: str) -> Callable: @@ -41,7 +41,7 @@ def get_pytorch_activation(activation: str) -> Callable: "pytorch": lambda weight, bias, activation: lambda x: get_pytorch_activation(activation)( torch.nn.functional.linear(x, weight, bias) ), - "triton": lambda weight, bias, activation, act_inputs: lambda x: linear_layer_fwd( + "triton": lambda weight, bias, activation, act_inputs: lambda x: linear_layer( x, weight, bias, activation, act_inputs ), } @@ -110,9 +110,9 @@ def test_benchmark_linear_forward( "pytorch": lambda weight, bias, activation: lambda x: get_pytorch_activation(activation)( torch.nn.functional.linear(x, weight, bias).backward() ), - "triton": lambda weight, bias, activation, act_inputs: lambda x: linear_layer_bwd( + "triton": lambda weight, bias, activation, act_inputs: lambda x: linear_layer( x, weight, bias, activation, act_inputs - ), + ).backward(), } @@ -151,7 +151,7 @@ def test_benchmark_linear_backward( layer_weight = torch.randn((N, K), **factory_kwargs) layer_bias = torch.randn((K,), **factory_kwargs) if bias else None linear_output = torch.nn.functional.linear(x, layer_weight, layer_bias) - expected = linear_output.backward(linear_output, retain_graph=True) + linear_output.backward(linear_output, retain_graph=True) # tensors casting layer_weight = layer_weight.to(dtype=dtype) @@ -166,4 +166,4 @@ def test_benchmark_linear_backward( value = benchmark(fn, x) - assert_all_close(expected, value.float(), rtol=1e-1, atol=1e-1) + assert_all_close(linear_layer, value.float(), rtol=1e-1, atol=1e-1) From 738c70fa49420085eccfb16aaf0baeee3b42223e Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Wed, 11 Jan 2023 17:12:38 +0100 Subject: [PATCH 09/23] feat: use FunctionCtx to pass tensors from 
forward to backward function --- src/kernl/implementations/attention.py | 2 +- src/kernl/implementations/linear_layer.py | 13 +++++-------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/kernl/implementations/attention.py b/src/kernl/implementations/attention.py index b9a310f8..1b79912b 100644 --- a/src/kernl/implementations/attention.py +++ b/src/kernl/implementations/attention.py @@ -66,7 +66,7 @@ def closest_power_of_2(n: int, min_range: int = 16, max_range: int = 128) -> Lis n = max(min(n, max_range), min_range) min_range = math.floor(math.log2(n - 1)) max_range = math.ceil(math.log2(n + 1)) - ranges = [2**i for i in range(min_range, max_range + 1)] + ranges = [2 ** i for i in range(min_range, max_range + 1)] return ranges diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index 1840ab3a..2625eaf4 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -363,6 +363,7 @@ def forward( assert bias is None or bias.shape[0] == weight.shape[0], "Incompatible dimensions in between weight and bias" assert weight.is_contiguous() + ctx.activation = activation M, K = x_.shape N, K = weight.shape @@ -397,15 +398,13 @@ def forward( outputs = outputs if x.ndim == 2 else outputs.reshape(x.shape[0], -1, N) ctx.save_for_backward(weight, bias, x) - return outputs + return tuple[outputs, act_inputs] @staticmethod @custom_bwd def backward( ctx: FunctionCtx, grad_outputs: torch.Tensor, - weight: torch.Tensor, - activation: str = "id", act_inputs: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ @@ -413,13 +412,11 @@ def backward( This wrapper kicks the `kernel_fwd` Triton kernel :param ctx: context for autograd :param grad_outputs: input tensor - :param weight: weight matrix :param activation: Activation name. Needs to be a Triton kernel. 
:param act_inputs: an optional tensor to save the activation inputs (for backward) :return: result tensor """ - assert activation in ["", "id", "gelu", "gelu_approx", "squared_relu"] - + weight, bias, x = ctx.saved_tensors batch_shape, n = grad_outputs.shape[:-1], grad_outputs.shape[-1] batch_dim = batch_shape.numel() grad_output_reshaped = grad_outputs.reshape(batch_dim, n) @@ -435,7 +432,7 @@ def backward( assert ( grad_output_reshaped.shape[1] == weight.shape[0] ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}" - if activation != "id": + if ctx.activation != "id": assert act_inputs is not None, f"act_input is required for activation {activation}" # M, N, K in bwd are different from M, N, K in fwd @@ -464,7 +461,7 @@ def backward( stride_ik=grad_output_reshaped.stride(1), stride_wn=weight.stride(0), stride_wk=weight.stride(1), - ACTIVATION=activation, # optional fused activation + ACTIVATION=ctx.activation, # optional fused activation GROUP_M=8, # speed optimization: group the programs ) From 73cc6e25983e81b1ddbb319349d7b35f6acd97d5 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Wed, 11 Jan 2023 17:58:21 +0100 Subject: [PATCH 10/23] feat: fix forward benchmark --- src/kernl/implementations/linear_layer.py | 4 +--- test/test_linear_layer.py | 9 ++------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index 2625eaf4..c04d8fb9 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -138,7 +138,6 @@ def kernel_fma( This kernel will consolidate over K """ - assert ACTIVATION in ["", "tanh", "gelu", "fast_gelu", "relu"], f"{ACTIVATION} is not supported" pid = tl.program_id(axis=0) grid_m = (M + BLOCK_M - 1) // BLOCK_M @@ -268,7 +267,6 @@ def kernel_bwd( EVEN_K: tl.constexpr, ACTIVATION: tl.constexpr, ): - assert ACTIVATION in ["", "tanh", "gelu", "fast_gelu", "relu"], f"{ACTIVATION} is not supported" pid = tl.program_id(axis=0) grid_m = (M + BLOCK_M - 1) // BLOCK_M @@ -398,7 +396,7 @@ def forward( outputs = outputs if x.ndim == 2 else outputs.reshape(x.shape[0], -1, N) ctx.save_for_backward(weight, bias, x) - return tuple[outputs, act_inputs] + return outputs, act_inputs @staticmethod @custom_bwd diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py index c16c65ad..285dc371 100644 --- a/test/test_linear_layer.py +++ b/test/test_linear_layer.py @@ -50,7 +50,6 @@ def get_pytorch_activation(activation: str) -> Callable: @set_seed() @pytest.mark.parametrize("contiguous", [True, False], ids=["contiguous", "non-contiguous"]) @pytest.mark.parametrize("bias", [True, False], ids=["with_bias", "no_bias"]) -@pytest.mark.parametrize("act_inputs", [True, False], ids=["with_act_inputs", "no_act_inputs"]) @pytest.mark.parametrize("activation", ["", "tanh", "gelu", "relu"], ids=["no_activation", "tanh", "gelu", "relu"]) @pytest.mark.parametrize( "shape", @@ -69,7 +68,6 @@ def test_benchmark_linear_forward( bias: bool, activation: str, contiguous: bool, - act_inputs: bool, ): batch, M, N, K = shape @@ -83,7 +81,6 @@ def test_benchmark_linear_forward( factory_kwargs = {"device": "cuda", "dtype": torch.float32, "requires_grad": False} layer_weight = torch.randn((N, K), **factory_kwargs) layer_bias = torch.randn((K,), **factory_kwargs) if bias else None - act_inputs = torch.zeros((M, N), **factory_kwargs) if act_inputs else None pytorch_layer_activation = get_pytorch_activation(activation) expected = 
pytorch_layer_activation(torch.nn.functional.linear(x, layer_weight, layer_bias)) @@ -91,17 +88,15 @@ def test_benchmark_linear_forward( layer_weight = layer_weight.to(dtype=dtype) if layer_bias is not None: layer_bias = layer_bias.to(dtype=dtype) - if act_inputs is not None: - act_inputs = act_inputs.to(dtype=dtype) x = x.to(dtype=dtype) - fn = forward_implementations[implementation](layer_weight, layer_bias, activation, act_inputs) + fn = forward_implementations[implementation](layer_weight, layer_bias, activation, None) if cuda_graphs: run = cuda_graphs_wrapper(model=fn, inputs=[x]) # CUDA graphs wraps output in a tuple fn = lambda tensor: run([tensor])[0] # noqa: E731 - value = benchmark(fn, x) + value = benchmark(fn, x)[0] assert_all_close(expected, value.float(), rtol=1e-1, atol=1e-1) From 74e3643eaf3e0a758148c898a6073245b1a94c61 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Fri, 13 Jan 2023 11:05:02 +0100 Subject: [PATCH 11/23] feat: fix benchmark backward --- src/kernl/implementations/linear_layer.py | 23 ++++++++---------- test/test_linear_layer.py | 29 +++++++++++------------ 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index c04d8fb9..c8b2784b 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -16,13 +16,13 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional +from typing import Any, Optional import torch import triton import triton.language as tl from torch.autograd.function import FunctionCtx -from torch.cuda.amp import custom_bwd, custom_fwd +from torch.cuda.amp import custom_fwd from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time from kernl.implementations import activation_func @@ -306,7 +306,7 @@ def kernel_bwd( B += BLOCK_K * stride_wk # optional: fused activation (while the data is in shared memory) - if ACTIVATION not in ["", "id"]: + if ACTIVATION != "": act_in_ptrs = ACT_INPUT + ram[:, None] * stride_om + rbn[None, :] * stride_on act_input = tl.load(act_in_ptrs).to(acc.dtype) if ACTIVATION == "tanh": @@ -365,7 +365,7 @@ def forward( M, K = x_.shape N, K = weight.shape - outputs = torch.empty((M, N), device=x.device, dtype=x.dtype) + outputs = torch.empty((M, N), device=x.device, dtype=x.dtype, requires_grad=True) # 1D launch kernel where each block gets its own program. grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa @@ -396,25 +396,22 @@ def forward( outputs = outputs if x.ndim == 2 else outputs.reshape(x.shape[0], -1, N) ctx.save_for_backward(weight, bias, x) - return outputs, act_inputs + return outputs @staticmethod - @custom_bwd def backward( ctx: FunctionCtx, - grad_outputs: torch.Tensor, - act_inputs: Optional[torch.Tensor] = None, + *grad_outputs: Any, ) -> torch.Tensor: """ Compute e = activation(grad_output @ weight + bias). This wrapper kicks the `kernel_fwd` Triton kernel :param ctx: context for autograd :param grad_outputs: input tensor - :param activation: Activation name. Needs to be a Triton kernel. 
- :param act_inputs: an optional tensor to save the activation inputs (for backward) :return: result tensor """ - weight, bias, x = ctx.saved_tensors + weight, bias, act_inputs = ctx.saved_tensors + grad_outputs = grad_outputs[0] batch_shape, n = grad_outputs.shape[:-1], grad_outputs.shape[-1] batch_dim = batch_shape.numel() grad_output_reshaped = grad_outputs.reshape(batch_dim, n) @@ -431,7 +428,7 @@ def backward( grad_output_reshaped.shape[1] == weight.shape[0] ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}" if ctx.activation != "id": - assert act_inputs is not None, f"act_input is required for activation {activation}" + assert act_inputs is not None, f"act_input is required for activation {ctx.activation}" # M, N, K in bwd are different from M, N, K in fwd M, K = grad_output_reshaped.shape @@ -463,7 +460,7 @@ def backward( GROUP_M=8, # speed optimization: group the programs ) - return grad_input.reshape(*batch_shape, grad_input.shape[-1]) + return grad_input.reshape(*batch_shape, grad_input.shape[-1]), None, None, None, None def linear_layer( diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py index 285dc371..369b06f5 100644 --- a/test/test_linear_layer.py +++ b/test/test_linear_layer.py @@ -41,9 +41,7 @@ def get_pytorch_activation(activation: str) -> Callable: "pytorch": lambda weight, bias, activation: lambda x: get_pytorch_activation(activation)( torch.nn.functional.linear(x, weight, bias) ), - "triton": lambda weight, bias, activation, act_inputs: lambda x: linear_layer( - x, weight, bias, activation, act_inputs - ), + "triton": lambda weight, bias, activation: lambda x: linear_layer(x, weight, bias, activation, None), } @@ -90,24 +88,24 @@ def test_benchmark_linear_forward( layer_bias = layer_bias.to(dtype=dtype) x = x.to(dtype=dtype) - fn = forward_implementations[implementation](layer_weight, layer_bias, activation, None) + fn = forward_implementations[implementation](layer_weight, layer_bias, activation) if cuda_graphs: run = cuda_graphs_wrapper(model=fn, inputs=[x]) # CUDA graphs wraps output in a tuple fn = lambda tensor: run([tensor])[0] # noqa: E731 - value = benchmark(fn, x)[0] + value = benchmark(fn, x) assert_all_close(expected, value.float(), rtol=1e-1, atol=1e-1) backward_implementations = { - "pytorch": lambda weight, bias, activation: lambda x: get_pytorch_activation(activation)( - torch.nn.functional.linear(x, weight, bias).backward() + "pytorch": lambda weight, bias, activation, grad: lambda x: get_pytorch_activation(activation)( + torch.nn.functional.linear(x, weight, bias).backward(gradient=grad, retain_graph=True) ), - "triton": lambda weight, bias, activation, act_inputs: lambda x: linear_layer( + "triton": lambda weight, bias, activation, act_inputs, grad: lambda x: linear_layer( x, weight, bias, activation, act_inputs - ).backward(), + ).backward(gradient=grad, retain_graph=True), } @@ -145,20 +143,21 @@ def test_benchmark_linear_backward( factory_kwargs = {"device": "cuda", "dtype": torch.float32, "requires_grad": True} layer_weight = torch.randn((N, K), **factory_kwargs) layer_bias = torch.randn((K,), **factory_kwargs) if bias else None - linear_output = torch.nn.functional.linear(x, layer_weight, layer_bias) - linear_output.backward(linear_output, retain_graph=True) + grad = torch.randn((batch, M, K), **factory_kwargs) + fwd_output = torch.nn.functional.linear(x, layer_weight, layer_bias) + fwd_output.backward(grad, retain_graph=True) # tensors casting layer_weight = layer_weight.to(dtype=dtype) - act_inputs = 
torch.zeros((M, N), **factory_kwargs) + act_inputs = torch.ones((M, N), **factory_kwargs) x = x.to(dtype=dtype) - fn = backward_implementations[implementation](layer_weight, layer_bias, activation, act_inputs) + fn = backward_implementations[implementation](layer_weight, layer_bias, activation, act_inputs, grad) if cuda_graphs: run = cuda_graphs_wrapper(model=fn, inputs=[x]) # CUDA graphs wraps output in a tuple fn = lambda tensor: run([tensor])[0] # noqa: E731 - value = benchmark(fn, x) + _ = benchmark(fn, x) - assert_all_close(linear_layer, value.float(), rtol=1e-1, atol=1e-1) + # assert_all_close(fwd_output, value, rtol=1e-1, atol=1e-1) From c17002f8630ae32f359692f15d04db5e914af1d6 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Fri, 13 Jan 2023 11:07:11 +0100 Subject: [PATCH 12/23] feat: fix benchmark backward --- src/kernl/implementations/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kernl/implementations/attention.py b/src/kernl/implementations/attention.py index 1b79912b..b9a310f8 100644 --- a/src/kernl/implementations/attention.py +++ b/src/kernl/implementations/attention.py @@ -66,7 +66,7 @@ def closest_power_of_2(n: int, min_range: int = 16, max_range: int = 128) -> Lis n = max(min(n, max_range), min_range) min_range = math.floor(math.log2(n - 1)) max_range = math.ceil(math.log2(n + 1)) - ranges = [2 ** i for i in range(min_range, max_range + 1)] + ranges = [2**i for i in range(min_range, max_range + 1)] return ranges From ceb40be5b6f72fde15b529cc4867df138eda04de Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Tue, 17 Jan 2023 17:01:54 +0100 Subject: [PATCH 13/23] feat: add backward benchmark --- src/kernl/implementations/linear_layer.py | 9 ++-- test/test_linear_layer.py | 55 ++++++++++++----------- 2 files changed, 36 insertions(+), 28 deletions(-) diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index c8b2784b..3d8efb09 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -410,10 +410,13 @@ def backward( :param grad_outputs: input tensor :return: result tensor """ + print("Entering custom backward") weight, bias, act_inputs = ctx.saved_tensors + print(f"x is leaf: {act_inputs.is_leaf}") + print(f"weight: {weight} \n bias: {bias} \n act_inputs: {act_inputs}\n activation: {ctx.activation}") grad_outputs = grad_outputs[0] batch_shape, n = grad_outputs.shape[:-1], grad_outputs.shape[-1] - batch_dim = batch_shape.numel() + batch_dim = batch_shape[0] * batch_shape[1] grad_output_reshaped = grad_outputs.reshape(batch_dim, n) if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1: @@ -459,8 +462,8 @@ def backward( ACTIVATION=ctx.activation, # optional fused activation GROUP_M=8, # speed optimization: group the programs ) - - return grad_input.reshape(*batch_shape, grad_input.shape[-1]), None, None, None, None + print(f"grad input {grad_input}") + return grad_input.reshape(*batch_shape, grad_input.shape[-1]), weight, bias, None, None def linear_layer( diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py index 5e6afa0b..ac00ac03 100644 --- a/test/test_linear_layer.py +++ b/test/test_linear_layer.py @@ -13,10 +13,12 @@ # limitations under the License. 
# +import os from typing import Callable, Tuple import pytest import torch +from torch.nn import MSELoss from conftest import assert_all_close, set_seed @@ -24,6 +26,9 @@ from kernl.optimizer.cuda_graph import cuda_graphs_wrapper +os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + + def get_pytorch_activation(activation: str) -> Callable: if activation == "gelu": return torch.nn.functional.gelu @@ -47,7 +52,6 @@ def get_pytorch_activation(activation: str) -> Callable: @set_seed() @pytest.mark.parametrize("contiguous", [True, False], ids=["contiguous", "non-contiguous"]) -@pytest.mark.parametrize("bias", [True, False], ids=["with_bias", "no_bias"]) @pytest.mark.parametrize("activation", ["", "tanh", "gelu", "relu"], ids=["no_activation", "tanh", "gelu", "relu"]) @pytest.mark.parametrize( "shape", @@ -63,7 +67,6 @@ def test_benchmark_linear_forward( cuda_graphs: bool, shape: Tuple[int, int, int, int], dtype: torch.dtype, - bias: bool, activation: str, contiguous: bool, ): @@ -78,7 +81,7 @@ def test_benchmark_linear_forward( assert not x.is_contiguous() factory_kwargs = {"device": "cuda", "dtype": torch.float32, "requires_grad": False} layer_weight = torch.randn((N, K), **factory_kwargs) - layer_bias = torch.randn((K,), **factory_kwargs) if bias else None + layer_bias = torch.randn((K,), **factory_kwargs) pytorch_layer_activation = get_pytorch_activation(activation) expected = pytorch_layer_activation(torch.nn.functional.linear(x, layer_weight, layer_bias)) @@ -100,34 +103,32 @@ def test_benchmark_linear_forward( backward_implementations = { - "pytorch": lambda weight, bias, activation, grad: lambda x: get_pytorch_activation(activation)( - torch.nn.functional.linear(x, weight, bias).backward(gradient=grad, retain_graph=True) - ), - "triton": lambda weight, bias, activation, act_inputs, grad: lambda x: linear_layer( - x, weight, bias, activation, act_inputs - ).backward(gradient=grad, retain_graph=True), + "pytorch": lambda weight, bias, activation, random_output: lambda x: MSELoss()( + get_pytorch_activation(activation)(torch.nn.functional.linear(x, weight, bias)), random_output + ).backward(), + "triton": lambda weight, bias, activation, random_output: lambda x: MSELoss()( + linear_layer(x, weight, bias, activation, None), random_output + ).backward(), } @set_seed() @pytest.mark.parametrize("contiguous", [True, False], ids=["contiguous", "non-contiguous"]) -@pytest.mark.parametrize("bias", [True, False], ids=["with_bias", "no_bias"]) @pytest.mark.parametrize("activation", ["", "tanh", "gelu", "relu"], ids=["no_activation", "tanh", "gelu", "relu"]) @pytest.mark.parametrize( "shape", [(1, 8, 8, 8)] + [(bs, M, 768, 768) for bs in [1, 16] for M in [8, 16, 128, 256, 512]], ids=lambda s: "x".join(map(str, s)), ) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"]) -@pytest.mark.parametrize("cuda_graphs", [False], ids=["no_cuda_graphs"]) -@pytest.mark.parametrize("implementation", ["triton", "pytorch"]) +@pytest.mark.parametrize("dtype", [torch.float32], ids=["fp32"]) +@pytest.mark.parametrize("cuda_graphs", [False, True], ids=["no_cuda_graphs", "cuda_graphs"]) +@pytest.mark.parametrize("implementation", ["triton"]) def test_benchmark_linear_backward( benchmark, implementation: str, cuda_graphs: bool, shape: Tuple[int, int, int, int], dtype: torch.dtype, - bias: bool, activation: str, contiguous: bool, ): @@ -140,24 +141,28 @@ def test_benchmark_linear_backward( x = x.contiguous() else: assert not x.is_contiguous() + x.retain_grad() # force saving grad factory_kwargs = 
{"device": "cuda", "dtype": torch.float32, "requires_grad": True} layer_weight = torch.randn((N, K), **factory_kwargs) - layer_bias = torch.randn((K,), **factory_kwargs) if bias else None - grad = torch.randn((batch, M, K), **factory_kwargs) - fwd_output = torch.nn.functional.linear(x, layer_weight, layer_bias) - fwd_output.backward(grad, retain_graph=True) - + layer_bias = torch.randn((K,), **factory_kwargs) # tensors casting layer_weight = layer_weight.to(dtype=dtype) - act_inputs = torch.ones((M, N), **factory_kwargs) + layer_bias = layer_bias.to(dtype=dtype) x = x.to(dtype=dtype) + x_triton = torch.clone(x) + x_triton = x_triton.to(dtype=dtype) + x_triton.retain_grad() - fn = backward_implementations[implementation](layer_weight, layer_bias, activation, act_inputs, grad) + pytorch_layer_activation = get_pytorch_activation(activation) + pytorch_fwd_output = pytorch_layer_activation(torch.nn.functional.linear(x, layer_weight, layer_bias)) + random_output = torch.randn(pytorch_fwd_output.shape, **factory_kwargs) + loss = MSELoss() + loss(pytorch_fwd_output, random_output).backward() + fn = backward_implementations[implementation](layer_weight, layer_bias, activation, random_output) if cuda_graphs: - run = cuda_graphs_wrapper(model=fn, inputs=[x]) + run = cuda_graphs_wrapper(model=fn, inputs=[x_triton]) # CUDA graphs wraps output in a tuple fn = lambda tensor: run([tensor])[0] # noqa: E731 - _ = benchmark(fn, x) - - # assert_all_close(fwd_output, value, rtol=1e-1, atol=1e-1) + _ = benchmark(fn, x_triton) + assert_all_close(x.grad, x_triton.grad, rtol=1e-1, atol=1e-1) From b29b9ed0b381c7df37d89cc60fd6452ec71b186f Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Tue, 17 Jan 2023 18:08:55 +0100 Subject: [PATCH 14/23] feat: add missing tests --- test/test_linear_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py index ac00ac03..69997507 100644 --- a/test/test_linear_layer.py +++ b/test/test_linear_layer.py @@ -120,9 +120,9 @@ def test_benchmark_linear_forward( [(1, 8, 8, 8)] + [(bs, M, 768, 768) for bs in [1, 16] for M in [8, 16, 128, 256, 512]], ids=lambda s: "x".join(map(str, s)), ) -@pytest.mark.parametrize("dtype", [torch.float32], ids=["fp32"]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"]) @pytest.mark.parametrize("cuda_graphs", [False, True], ids=["no_cuda_graphs", "cuda_graphs"]) -@pytest.mark.parametrize("implementation", ["triton"]) +@pytest.mark.parametrize("implementation", ["triton", "pytorch"]) def test_benchmark_linear_backward( benchmark, implementation: str, From 0103ec70ce20e556719a53156e7adab439336f92 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Wed, 18 Jan 2023 11:34:50 +0100 Subject: [PATCH 15/23] feat: remove debugging print in backward function + fix fp16 tests with autocast and custom_bwd --- src/kernl/implementations/linear_layer.py | 9 +++------ test/test_linear_layer.py | 19 ++++++++++++------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index 3d8efb09..bef5fbe2 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -22,7 +22,7 @@ import triton import triton.language as tl from torch.autograd.function import FunctionCtx -from torch.cuda.amp import custom_fwd +from torch.cuda.amp import custom_bwd, custom_fwd from triton.ops.matmul_perf_model import early_config_prune, 
estimate_matmul_time from kernl.implementations import activation_func @@ -399,6 +399,7 @@ def forward( return outputs @staticmethod + @custom_bwd def backward( ctx: FunctionCtx, *grad_outputs: Any, @@ -410,10 +411,7 @@ def backward( :param grad_outputs: input tensor :return: result tensor """ - print("Entering custom backward") weight, bias, act_inputs = ctx.saved_tensors - print(f"x is leaf: {act_inputs.is_leaf}") - print(f"weight: {weight} \n bias: {bias} \n act_inputs: {act_inputs}\n activation: {ctx.activation}") grad_outputs = grad_outputs[0] batch_shape, n = grad_outputs.shape[:-1], grad_outputs.shape[-1] batch_dim = batch_shape[0] * batch_shape[1] @@ -462,8 +460,7 @@ def backward( ACTIVATION=ctx.activation, # optional fused activation GROUP_M=8, # speed optimization: group the programs ) - print(f"grad input {grad_input}") - return grad_input.reshape(*batch_shape, grad_input.shape[-1]), weight, bias, None, None + return grad_input.reshape(*batch_shape, grad_input.shape[-1]), None, None, None, None def linear_layer( diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py index 69997507..85524a76 100644 --- a/test/test_linear_layer.py +++ b/test/test_linear_layer.py @@ -18,6 +18,7 @@ import pytest import torch +from torch.cuda.amp import autocast from torch.nn import MSELoss from conftest import assert_all_close, set_seed @@ -133,31 +134,35 @@ def test_benchmark_linear_backward( contiguous: bool, ): batch, M, N, K = shape - + factory_kwargs = {"device": "cuda", "dtype": torch.float32, "requires_grad": True} # order of dimensions is wrong so we force contiguous call - x = torch.randn((batch, K, M), device="cuda", dtype=torch.float32, requires_grad=True) + x = torch.randn((batch, K, M), **factory_kwargs) x = x.mT if contiguous: x = x.contiguous() else: assert not x.is_contiguous() + x = x.to(dtype=dtype) x.retain_grad() # force saving grad - factory_kwargs = {"device": "cuda", "dtype": torch.float32, "requires_grad": True} + layer_weight = torch.randn((N, K), **factory_kwargs) - layer_bias = torch.randn((K,), **factory_kwargs) - # tensors casting layer_weight = layer_weight.to(dtype=dtype) + layer_bias = torch.randn((K,), **factory_kwargs) layer_bias = layer_bias.to(dtype=dtype) - x = x.to(dtype=dtype) + x_triton = torch.clone(x) x_triton = x_triton.to(dtype=dtype) + assert x_triton.is_contiguous() == x.is_contiguous() x_triton.retain_grad() pytorch_layer_activation = get_pytorch_activation(activation) pytorch_fwd_output = pytorch_layer_activation(torch.nn.functional.linear(x, layer_weight, layer_bias)) random_output = torch.randn(pytorch_fwd_output.shape, **factory_kwargs) + random_output = random_output.to(dtype=dtype) loss = MSELoss() - loss(pytorch_fwd_output, random_output).backward() + + with autocast(dtype=dtype): + loss(pytorch_fwd_output, random_output).backward() fn = backward_implementations[implementation](layer_weight, layer_bias, activation, random_output) if cuda_graphs: run = cuda_graphs_wrapper(model=fn, inputs=[x_triton]) From ccc7a5f3a6d54f2d235f4aec87b77bb81bf0cb51 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Fri, 20 Jan 2023 14:42:23 +0100 Subject: [PATCH 16/23] feat: fix relu grad + fix benchmark tests --- src/kernl/implementations/activation_func.py | 2 +- src/kernl/implementations/linear_layer.py | 2 +- test/test_linear_layer.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/kernl/implementations/activation_func.py b/src/kernl/implementations/activation_func.py index 20e2415b..30e8800b 100644 --- 
a/src/kernl/implementations/activation_func.py +++ b/src/kernl/implementations/activation_func.py @@ -47,7 +47,7 @@ def relu(x): @triton.jit def relu_grad(x): """Relu derivative function""" - return tl.maximum(0, x) + return tl.maximum(0, 1) @triton.jit diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index bef5fbe2..b60bff89 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -414,7 +414,7 @@ def backward( weight, bias, act_inputs = ctx.saved_tensors grad_outputs = grad_outputs[0] batch_shape, n = grad_outputs.shape[:-1], grad_outputs.shape[-1] - batch_dim = batch_shape[0] * batch_shape[1] + batch_dim = batch_shape.numel() grad_output_reshaped = grad_outputs.reshape(batch_dim, n) if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1: diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py index 85524a76..74b0a7b2 100644 --- a/test/test_linear_layer.py +++ b/test/test_linear_layer.py @@ -104,10 +104,10 @@ def test_benchmark_linear_forward( backward_implementations = { - "pytorch": lambda weight, bias, activation, random_output: lambda x: MSELoss()( + "pytorch": lambda weight, bias, activation, random_output: lambda x: MSELoss(reduction='sum')( get_pytorch_activation(activation)(torch.nn.functional.linear(x, weight, bias)), random_output ).backward(), - "triton": lambda weight, bias, activation, random_output: lambda x: MSELoss()( + "triton": lambda weight, bias, activation, random_output: lambda x: MSELoss(reduction='sum')( linear_layer(x, weight, bias, activation, None), random_output ).backward(), } From b646a9ad62d3e22e5780cb150c5a96cb660e65c9 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Mon, 23 Jan 2023 09:37:25 +0100 Subject: [PATCH 17/23] feat: fix gelu grad function --- src/kernl/implementations/activation_func.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/kernl/implementations/activation_func.py b/src/kernl/implementations/activation_func.py index 30e8800b..804d8b28 100644 --- a/src/kernl/implementations/activation_func.py +++ b/src/kernl/implementations/activation_func.py @@ -23,6 +23,7 @@ sqrt2pi = math.sqrt(2.0 / math.pi) sqrt2 = math.sqrt(2.0) +sqrt1_2 = math.sqrt(1.0 / 2) gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi) @@ -74,6 +75,6 @@ def gelu(x): @triton.jit def gelu_grad(x): """Derivative of Gaussian Error Linear Unit (GELU)""" - cdf = 0.5 * (1.0 + tl.libdevice.erf(x * sqrt2)) + cdf = 0.5 * (1.0 + tl.libdevice.erf(x * sqrt1_2)) pdf = tl.exp(-0.5 * x * x) * gaussian_pdf_normalization return cdf + x * pdf From d162d277525047463fa84c838a005284818fa739 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Mon, 23 Jan 2023 09:54:16 +0100 Subject: [PATCH 18/23] feat: simplify activation function --- src/kernl/implementations/activation_func.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/kernl/implementations/activation_func.py b/src/kernl/implementations/activation_func.py index 804d8b28..15d92115 100644 --- a/src/kernl/implementations/activation_func.py +++ b/src/kernl/implementations/activation_func.py @@ -23,7 +23,6 @@ sqrt2pi = math.sqrt(2.0 / math.pi) sqrt2 = math.sqrt(2.0) -sqrt1_2 = math.sqrt(1.0 / 2) gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi) @@ -75,6 +74,6 @@ def gelu(x): @triton.jit def gelu_grad(x): """Derivative of Gaussian Error Linear Unit (GELU)""" - cdf = 0.5 * (1.0 + tl.libdevice.erf(x * sqrt1_2)) + cdf = 0.5 * (1.0 + tl.libdevice.erf(x / 
sqrt2)) pdf = tl.exp(-0.5 * x * x) * gaussian_pdf_normalization return cdf + x * pdf From 2d7e501a39662fbb020ea88d294e83b031ec94df Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Mon, 23 Jan 2023 17:45:12 +0100 Subject: [PATCH 19/23] feat: fix stride -> fix unit tests --- src/kernl/implementations/linear_layer.py | 6 +++--- test/test_linear_layer.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index b60bff89..a6163701 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -182,7 +182,7 @@ def kernel_fma( # optional: save the activation inputs if SAVE_ACT_INPUTS: - act_in_ptrs = ACT_INPUTS + ram[:, None] * stride_om + rbn[None, :] + act_in_ptrs = ACT_INPUTS + ram[:, None] * stride_om + rbn[None, :] * stride_on tl.store(act_in_ptrs, acc) # optional: fused activation (while the data is in shared memory) @@ -455,8 +455,8 @@ def backward( stride_on=grad_input.stride(1), stride_im=grad_output_reshaped.stride(0), stride_ik=grad_output_reshaped.stride(1), - stride_wn=weight.stride(0), - stride_wk=weight.stride(1), + stride_wn=weight.stride(1), + stride_wk=weight.stride(0), ACTIVATION=ctx.activation, # optional fused activation GROUP_M=8, # speed optimization: group the programs ) diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py index 74b0a7b2..4beea695 100644 --- a/test/test_linear_layer.py +++ b/test/test_linear_layer.py @@ -104,10 +104,10 @@ def test_benchmark_linear_forward( backward_implementations = { - "pytorch": lambda weight, bias, activation, random_output: lambda x: MSELoss(reduction='sum')( + "pytorch": lambda weight, bias, activation, random_output: lambda x: MSELoss(reduction="sum")( get_pytorch_activation(activation)(torch.nn.functional.linear(x, weight, bias)), random_output ).backward(), - "triton": lambda weight, bias, activation, random_output: lambda x: MSELoss(reduction='sum')( + "triton": lambda weight, bias, activation, random_output: lambda x: MSELoss(reduction="sum")( linear_layer(x, weight, bias, activation, None), random_output ).backward(), } From 2d4691761bdd7b4152e3ab637889b06b503660e8 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Mon, 23 Jan 2023 20:48:01 +0100 Subject: [PATCH 20/23] feat: fix tanh grad compilation --- src/kernl/implementations/activation_func.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/kernl/implementations/activation_func.py b/src/kernl/implementations/activation_func.py index 15d92115..42fbd579 100644 --- a/src/kernl/implementations/activation_func.py +++ b/src/kernl/implementations/activation_func.py @@ -35,7 +35,8 @@ def tanh(x): @triton.jit def tanh_grad(x): """Tanh derivative function""" - return 1 - tl.libdevice.pow(tl.libdevice.tanh(x), 2) + tanh_x = tanh(x) + return 1 - tanh_x * tanh_x @triton.jit From b7228e9676bd328e319634026157efcc522df377 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Mon, 23 Jan 2023 21:25:09 +0100 Subject: [PATCH 21/23] feat: remove unnecessary modifications --- src/kernl/implementations/linear_layer.py | 2 +- test/test_linear_layer.py | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index a6163701..0475bba6 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -365,7 +365,7 @@ def forward( M, K = x_.shape N, K = weight.shape - outputs = 
torch.empty((M, N), device=x.device, dtype=x.dtype, requires_grad=True) + outputs = torch.empty((M, N), device=x.device, dtype=x.dtype) # 1D launch kernel where each block gets its own program. grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py index 4beea695..93603910 100644 --- a/test/test_linear_layer.py +++ b/test/test_linear_layer.py @@ -27,9 +27,6 @@ from kernl.optimizer.cuda_graph import cuda_graphs_wrapper -os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - - def get_pytorch_activation(activation: str) -> Callable: if activation == "gelu": return torch.nn.functional.gelu @@ -47,12 +44,13 @@ def get_pytorch_activation(activation: str) -> Callable: "pytorch": lambda weight, bias, activation: lambda x: get_pytorch_activation(activation)( torch.nn.functional.linear(x, weight, bias) ), - "triton": lambda weight, bias, activation: lambda x: linear_layer(x, weight, bias, activation, None), + "triton": lambda weight, bias, activation: lambda x: linear_layer(x, weight, bias, activation), } @set_seed() @pytest.mark.parametrize("contiguous", [True, False], ids=["contiguous", "non-contiguous"]) +@pytest.mark.parametrize("bias", [True, False], ids=["with_bias", "no_bias"]) @pytest.mark.parametrize("activation", ["", "tanh", "gelu", "relu"], ids=["no_activation", "tanh", "gelu", "relu"]) @pytest.mark.parametrize( "shape", @@ -68,6 +66,7 @@ def test_benchmark_linear_forward( cuda_graphs: bool, shape: Tuple[int, int, int, int], dtype: torch.dtype, + bias: bool, activation: str, contiguous: bool, ): @@ -82,7 +81,7 @@ def test_benchmark_linear_forward( assert not x.is_contiguous() factory_kwargs = {"device": "cuda", "dtype": torch.float32, "requires_grad": False} layer_weight = torch.randn((N, K), **factory_kwargs) - layer_bias = torch.randn((K,), **factory_kwargs) + layer_bias = torch.randn((K,), **factory_kwargs) if bias else None pytorch_layer_activation = get_pytorch_activation(activation) expected = pytorch_layer_activation(torch.nn.functional.linear(x, layer_weight, layer_bias)) @@ -122,7 +121,7 @@ def test_benchmark_linear_forward( ids=lambda s: "x".join(map(str, s)), ) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"]) -@pytest.mark.parametrize("cuda_graphs", [False, True], ids=["no_cuda_graphs", "cuda_graphs"]) +@pytest.mark.parametrize("cuda_graphs", [False], ids=["no_cuda_graphs"]) @pytest.mark.parametrize("implementation", ["triton", "pytorch"]) def test_benchmark_linear_backward( benchmark, From 620be3a93316fdce2f07e0000e5c0d9bf2648afa Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Wed, 25 Jan 2023 00:35:11 +0100 Subject: [PATCH 22/23] feat: remove unused import --- test/test_linear_layer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py index 93603910..93574b8e 100644 --- a/test/test_linear_layer.py +++ b/test/test_linear_layer.py @@ -13,7 +13,6 @@ # limitations under the License. 
 #
 
-import os
 from typing import Callable, Tuple
 
 import pytest
From 5273fab3560b2a09b7ecec0c2dea5bfbedf9380c Mon Sep 17 00:00:00 2001
From: ayoub-louati
Date: Fri, 27 Jan 2023 12:06:59 +0100
Subject: [PATCH 23/23] feat: update linear layer forward and backward implementations

---
 src/kernl/implementations/linear_layer.py | 75 +++++++++++------------
 1 file changed, 36 insertions(+), 39 deletions(-)

diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py
index ccc391e9..de2bfb02 100644
--- a/src/kernl/implementations/linear_layer.py
+++ b/src/kernl/implementations/linear_layer.py
@@ -230,7 +230,7 @@ def kernel_fma(
 )
 @triton.heuristics(
     {
-        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
+        "K_LOAD_MASK_NEEDED": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
     }
 )
 @triton.jit
@@ -249,12 +249,12 @@ def kernel_bwd(
     # The stride variables represent how much to increase the ptr by when moving by 1
     # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
     # by to get the element one row down (A has M rows)
-    stride_om,
-    stride_on,
-    stride_im,
-    stride_ik,
-    stride_wn,
-    stride_wk,
+    output_m_stride,
+    output_n_stride,
+    a_m_stride,
+    a_k_stride,
+    w_n_stride,
+    w_k_stride,
     # Meta-parameters
     BLOCK_M: tl.constexpr,
     GROUP_M: tl.constexpr,
@@ -262,50 +262,51 @@ def kernel_bwd(
     BLOCK_K: tl.constexpr,
     # split k not used, not performant with activation, kept because early_config_prune is expecting it
     SPLIT_K: tl.constexpr,
-    EVEN_K: tl.constexpr,
+    K_LOAD_MASK_NEEDED: tl.constexpr,
     ACTIVATION: tl.constexpr,
 ):
-    pid = tl.program_id(axis=0)
+    program_idx = tl.program_id(axis=0)
 
     grid_m = (M + BLOCK_M - 1) // BLOCK_M
     grid_n = (N + BLOCK_N - 1) // BLOCK_N
     # re-order program ID for better L2 performance
     width = GROUP_M * grid_n
-    group_id = pid // width
-    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
-    pid_m = group_id * GROUP_M + (pid % group_size)
-    pid_n = (pid % width) // (group_size)
+    group_idx = program_idx // width
+    group_size = min(grid_m - group_idx * GROUP_M, GROUP_M)
+    block_m_idx = group_idx * GROUP_M + (program_idx % group_size)
+    block_n_idx = (program_idx % width) // (group_size)
 
     # now compute the block that each program will go through
-    # rm (resp. rn) denotes a range of indices
+    # m_offs_untagged (resp. n_offs_untagged) denotes a range of indices
     # for rows (resp. 
col) of C - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + m_offs_untagged = block_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + n_offs_untagged = block_n_idx * BLOCK_N + tl.arange(0, BLOCK_N) + # trick to avoid masking on M and N axis - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) + m_offs = tl.max_contiguous(tl.multiple_of(m_offs_untagged % M, BLOCK_M), BLOCK_M) + n_offs = tl.max_contiguous(tl.multiple_of(n_offs_untagged % N, BLOCK_N), BLOCK_N) + k_range_offs = tl.arange(0, BLOCK_K) - A = A + (ram[:, None] * stride_im + rk[None, :] * stride_ik) - B = B + (rk[:, None] * stride_wk + rbn[None, :] * stride_wn) + A = A + (m_offs[:, None] * a_m_stride + k_range_offs[None, :] * a_k_stride) + B = B + (k_range_offs[:, None] * w_k_stride + n_offs[None, :] * w_n_stride) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(K, 0, -BLOCK_K): - if EVEN_K: + if K_LOAD_MASK_NEEDED: a = tl.load(A) b = tl.load(B) else: - a = tl.load(A, mask=rk[None, :] < k, other=0.0) - b = tl.load(B, mask=rk[:, None] < k, other=0.0) + a = tl.load(A, mask=k_range_offs[None, :] < k, other=0.0) + b = tl.load(B, mask=k_range_offs[:, None] < k, other=0.0) acc += tl.dot(a, b) - A += BLOCK_K * stride_ik - B += BLOCK_K * stride_wk + A += BLOCK_K * a_k_stride + B += BLOCK_K * w_k_stride # optional: fused activation (while the data is in shared memory) if ACTIVATION != "": - act_in_ptrs = ACT_INPUT + ram[:, None] * stride_om + rbn[None, :] * stride_on + act_in_ptrs = ACT_INPUT + m_offs[:, None] * output_m_stride + n_offs[None, :] * output_n_stride act_input = tl.load(act_in_ptrs).to(acc.dtype) if ACTIVATION == "tanh": acc *= activation_func.tanh_grad(act_input) @@ -316,13 +317,9 @@ def kernel_bwd( if ACTIVATION == "relu": acc *= activation_func.relu_grad(act_input) - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - # write back result - C = C + rm[:, None] * stride_om + rn[None, :] * stride_on - mask = (rm < M)[:, None] & (rn < N)[None, :] + C = C + m_offs_untagged[:, None] * output_m_stride + n_offs_untagged[None, :] * output_n_stride + mask = (m_offs_untagged < M)[:, None] & (n_offs_untagged < N)[None, :] tl.store(C, acc, mask=mask) @@ -449,12 +446,12 @@ def backward( M // 32, # key for triton cache (limit number of compilations) N // 32, K // 32, - stride_om=grad_input.stride(0), # strides - stride_on=grad_input.stride(1), - stride_im=grad_output_reshaped.stride(0), - stride_ik=grad_output_reshaped.stride(1), - stride_wn=weight.stride(1), - stride_wk=weight.stride(0), + output_m_stride=grad_input.stride(0), # strides + output_n_stride=grad_input.stride(1), + a_m_stride=grad_output_reshaped.stride(0), + a_k_stride=grad_output_reshaped.stride(1), + w_n_stride=weight.stride(1), + w_k_stride=weight.stride(0), ACTIVATION=ctx.activation, # optional fused activation GROUP_M=8, # speed optimization: group the programs )
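
The backward kernel above scales the incoming gradient by the closed-form activation derivatives (tanh_grad, gelu_grad) added in activation_func.py. As a quick numerical sanity check — a standalone sketch that is not part of the patch series, with hypothetical helper names (gelu_grad_ref, tanh_grad_ref, autograd_grad) — the same formulas can be compared against torch.autograd outside Triton:

# Standalone check (illustrative only): verify the closed-form derivatives used by the
# fused backward kernel against PyTorch autograd, in float64 for tight tolerances.
import math

import torch

sqrt2 = math.sqrt(2.0)
gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)


def gelu_grad_ref(x: torch.Tensor) -> torch.Tensor:
    # d/dx [x * Phi(x)] = Phi(x) + x * phi(x), the same expression as gelu_grad above
    cdf = 0.5 * (1.0 + torch.erf(x / sqrt2))
    pdf = torch.exp(-0.5 * x * x) * gaussian_pdf_normalization
    return cdf + x * pdf


def tanh_grad_ref(x: torch.Tensor) -> torch.Tensor:
    # d/dx tanh(x) = 1 - tanh(x)^2, the same expression as tanh_grad above
    t = torch.tanh(x)
    return 1 - t * t


def autograd_grad(fn, x: torch.Tensor) -> torch.Tensor:
    # gradient of sum(fn(x)) w.r.t. x, i.e. the elementwise derivative of fn
    x = x.detach().clone().requires_grad_(True)
    fn(x).sum().backward()
    return x.grad


if __name__ == "__main__":
    x = torch.randn(4096, dtype=torch.float64)
    torch.testing.assert_close(gelu_grad_ref(x), autograd_grad(torch.nn.functional.gelu, x))
    torch.testing.assert_close(tanh_grad_ref(x), autograd_grad(torch.tanh, x))
    print("closed-form activation derivatives match autograd")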