diff --git a/src/kernl/implementations/activation_func.py b/src/kernl/implementations/activation_func.py
index 53408a5e..42fbd579 100644
--- a/src/kernl/implementations/activation_func.py
+++ b/src/kernl/implementations/activation_func.py
@@ -23,6 +23,7 @@
 sqrt2pi = math.sqrt(2.0 / math.pi)
 sqrt2 = math.sqrt(2.0)
+gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)
 
 
 @triton.jit
@@ -31,19 +32,49 @@ def tanh(x):
     return tl.libdevice.tanh(x)
 
 
+@triton.jit
+def tanh_grad(x):
+    """Tanh derivative function"""
+    tanh_x = tanh(x)
+    return 1 - tanh_x * tanh_x
+
+
 @triton.jit
 def relu(x):
     """Relu activation function"""
     return tl.maximum(0, x)
 
 
+@triton.jit
+def relu_grad(x):
+    """Relu derivative function"""
+    return tl.where(x >= 0, 1.0, 0.0)
+
+
 @triton.jit
 def fast_gelu(x):
     """Fast approximation of the gelu function. May slightly decrease accuracy."""
     return 0.5 * x * (1 + tanh(sqrt2pi * (x + 0.044715 * x * x * x)))
 
 
+@triton.jit
+def fast_gelu_grad(x):
+    """Derivative of the fast gelu approximation."""
+    # CREDITS: fast implementation proposed in
+    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
+    tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+    return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
+
+
 @triton.jit
 def gelu(x):
     """Gaussian Error Linear Unit (GELU)"""
     return x * 0.5 * (1.0 + tl.libdevice.erf(x / sqrt2))
+
+
+@triton.jit
+def gelu_grad(x):
+    """Derivative of Gaussian Error Linear Unit (GELU)"""
+    cdf = 0.5 * (1.0 + tl.libdevice.erf(x / sqrt2))
+    pdf = tl.exp(-0.5 * x * x) * gaussian_pdf_normalization
+    return cdf + x * pdf
diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py
index 68ad9144..d7bab453 100644
--- a/src/kernl/implementations/linear_layer.py
+++ b/src/kernl/implementations/linear_layer.py
@@ -16,13 +16,13 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Optional
+from typing import Any, Optional, Tuple
 
 import torch
 import triton
 import triton.language as tl
 from torch.autograd.function import FunctionCtx
-from torch.cuda.amp import custom_fwd
+from torch.cuda.amp import custom_bwd, custom_fwd
 from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
 
 from kernl.implementations import activation_func
@@ -203,6 +203,127 @@ def kernel_fma(
     tl.store(C, acc, mask=c_ptr_mask)
 
 
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
+        # good for int8
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
+    ]
+    + get_configs_io_bound(),
+    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
+    prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
+)
+@triton.heuristics(
+    {
+        "K_LOAD_MASK_NEEDED": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) != 0,
+    }
+)
+@triton.jit
+def kernel_bwd(
+    C,  # Pointers to matrices
+    ACT_INPUT,
+    A,
+    B,
+    # Matrix dimensions
+    M,
+    N,
+    K,
+    CACHE_KEY_M,
+    CACHE_KEY_N,
+    CACHE_KEY_K,
+    # The stride variables represent how much to increase the ptr by when moving by 1
+    # element in a particular dimension. E.g. a_m_stride is how much to increase A
+    # by to get the element one row down (A has M rows)
+    output_m_stride,
+    output_n_stride,
+    a_m_stride,
+    a_k_stride,
+    w_n_stride,
+    w_k_stride,
+    # Meta-parameters
+    BLOCK_M: tl.constexpr,
+    GROUP_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    # split k not used, not performant with activation, kept because early_config_prune is expecting it
+    SPLIT_K: tl.constexpr,
+    K_LOAD_MASK_NEEDED: tl.constexpr,
+    ACTIVATION: tl.constexpr,
+):
+    program_idx = tl.program_id(axis=0)
+
+    grid_m = (M + BLOCK_M - 1) // BLOCK_M
+    grid_n = (N + BLOCK_N - 1) // BLOCK_N
+    # re-order program ID for better L2 performance
+    width = GROUP_M * grid_n
+    group_idx = program_idx // width
+    group_size = min(grid_m - group_idx * GROUP_M, GROUP_M)
+    block_m_idx = group_idx * GROUP_M + (program_idx % group_size)
+    block_n_idx = (program_idx % width) // (group_size)
+
+    # now compute the block that each program will go through
+    # m_offs_untagged (resp. n_offs_untagged) denotes a range of indices
+    # for rows (resp. cols) of C
+    m_offs_untagged = block_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
+    n_offs_untagged = block_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    # trick to avoid masking on M and N axis
+    m_offs = tl.max_contiguous(tl.multiple_of(m_offs_untagged % M, BLOCK_M), BLOCK_M)
+    n_offs = tl.max_contiguous(tl.multiple_of(n_offs_untagged % N, BLOCK_N), BLOCK_N)
+    k_range_offs = tl.arange(0, BLOCK_K)
+
+    A = A + (m_offs[:, None] * a_m_stride + k_range_offs[None, :] * a_k_stride)
+    B = B + (k_range_offs[:, None] * w_k_stride + n_offs[None, :] * w_n_stride)
+
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+    for k in range(K, 0, -BLOCK_K):
+        if K_LOAD_MASK_NEEDED:
+            a = tl.load(A, mask=k_range_offs[None, :] < k, other=0.0)
+            b = tl.load(B, mask=k_range_offs[:, None] < k, other=0.0)
+        else:
+            a = tl.load(A)
+            b = tl.load(B)
+        acc += tl.dot(a, b)
+
+        A += BLOCK_K * a_k_stride
+        B += BLOCK_K * w_k_stride
+
+    # optional: fused activation gradient (while the data is in shared memory)
+    if ACTIVATION != "":
+        act_in_ptrs = ACT_INPUT + m_offs[:, None] * output_m_stride + n_offs[None, :] * output_n_stride
+        act_input = tl.load(act_in_ptrs).to(acc.dtype)
+        if ACTIVATION == "tanh":
+            acc *= activation_func.tanh_grad(act_input)
+        if ACTIVATION == "gelu":
+            acc *= activation_func.gelu_grad(act_input)
+        if ACTIVATION == "fast_gelu":
+            acc *= activation_func.fast_gelu_grad(act_input)
+        if ACTIVATION == "relu":
+            acc *= activation_func.relu_grad(act_input)
+
+    # write back result
+    C = C + m_offs_untagged[:, None] * output_m_stride + n_offs_untagged[None, :] * output_n_stride
+    mask = (m_offs_untagged < M)[:, None] & (n_offs_untagged < N)[None, :]
+    tl.store(C, acc, mask=mask)
+
+
 class LinearLayer(torch.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
@@ -236,6 +357,7 @@ def forward(
         assert bias is None or bias.shape[0] == weight.shape[0], "Incompatible dimensions in between weight and bias"
         assert weight.is_contiguous()
+        ctx.activation = activation
 
         M, K = x_.shape
         N, K = weight.shape
@@ -274,6 +396,70 @@ def forward(
         ctx.save_for_backward(weight, bias, x)
         return outputs
 
+    @staticmethod
+    @custom_bwd
+    def backward(
+        ctx: FunctionCtx,
+        *grad_outputs: Any,
+    ) -> Tuple[torch.Tensor, None, None, None, None]:
+        """
+        Compute the gradient of the loss with respect to the input of the linear layer.
+        This wrapper kicks the `kernel_bwd` Triton kernel
+        :param ctx: context for autograd
+        :param grad_outputs: gradient of the loss with respect to the layer output
+        :return: gradient with respect to the input, and None for the other inputs of forward
+        """
+        weight, bias, act_inputs = ctx.saved_tensors
+        grad_outputs = grad_outputs[0]
+        batch_shape, n = grad_outputs.shape[:-1], grad_outputs.shape[-1]
+        batch_dim = batch_shape.numel()
+        grad_output_reshaped = grad_outputs.reshape(batch_dim, n)
+
+        if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:
+            grad_output_reshaped = grad_output_reshaped.contiguous()
+        if weight.stride(0) > 1 and weight.stride(1) > 1:
+            weight = weight.contiguous()
+
+        assert (
+            grad_outputs.dtype == weight.dtype
+        ), f"grad_output and weight must have the same dtype, got {grad_outputs.dtype} and {weight.dtype}"
+        assert (
+            grad_output_reshaped.shape[1] == weight.shape[0]
+        ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
+        if ctx.activation != "":
+            assert act_inputs is not None, f"act_input is required for activation {ctx.activation}"
+
+        # M, N, K in bwd are different from M, N, K in fwd
+        M, K = grad_output_reshaped.shape
+        K, N = weight.shape
+
+        grad_input = torch.empty((M, N), device=grad_outputs.device, dtype=grad_outputs.dtype)
+
+        # 1D launch kernel where each block gets its own program.
+        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa
+
+        kernel_bwd[grid](
+            grad_input,
+            act_inputs,
+            grad_output_reshaped,
+            weight,  # data ptrs
+            M,  # shapes
+            N,
+            K,
+            M // 32,  # key for triton cache (limit number of compilations)
+            N // 32,
+            K // 32,
+            output_m_stride=grad_input.stride(0),  # strides
+            output_n_stride=grad_input.stride(1),
+            a_m_stride=grad_output_reshaped.stride(0),
+            a_k_stride=grad_output_reshaped.stride(1),
+            w_n_stride=weight.stride(1),
+            w_k_stride=weight.stride(0),
+            ACTIVATION=ctx.activation,  # optional fused activation
+            GROUP_M=8,  # speed optimization: group the programs
+        )
+        return grad_input.reshape(*batch_shape, grad_input.shape[-1]), None, None, None, None
+
 
 def linear_layer(
     x: torch.Tensor,
diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py
index ba777d2d..ede66319 100644
--- a/test/test_linear_layer.py
+++ b/test/test_linear_layer.py
@@ -17,6 +17,8 @@
 import pytest
 import torch
+from torch.cuda.amp import autocast
+from torch.nn import MSELoss
 
 from conftest import assert_all_close, set_seed
@@ -37,7 +39,7 @@ def get_pytorch_activation(activation: str) -> Callable:
     raise ValueError(f"Unknown activation: {activation}")
 
 
-implementations = {
+forward_implementations = {
     "pytorch": lambda weight, bias, activation: lambda x: get_pytorch_activation(activation)(
         torch.nn.functional.linear(x, weight, bias)
     ),
@@ -57,7 +59,7 @@ def get_pytorch_activation(activation: str) -> Callable:
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"])
 @pytest.mark.parametrize("cuda_graphs", [False, True], ids=["no_cuda_graphs", "cuda_graphs"])
 @pytest.mark.parametrize("implementation", ["triton", "pytorch"])
-def test_benchmark(
+def test_benchmark_linear_forward(
     benchmark,
     implementation: str,
     cuda_graphs: bool,
@@ -88,7 +90,7 @@ def test_benchmark(
     layer_bias = layer_bias.to(dtype=dtype)
     x = x.to(dtype=dtype)
 
-    fn = implementations[implementation](layer_weight, layer_bias, activation)
+    fn = forward_implementations[implementation](layer_weight, layer_bias, activation)
     if cuda_graphs:
         run = cuda_graphs_wrapper(model=fn, inputs=[x])
         # CUDA graphs wraps output in a tuple
@@ -97,3 +99,73 @@ def test_benchmark(
     value = benchmark(fn, x)
 
     assert_all_close(expected, value.float(), rtol=1e-1, atol=1e-1)
+
+
+backward_implementations = {
+    "pytorch": lambda weight, bias, activation, random_output: lambda x: MSELoss(reduction="sum")(
+        get_pytorch_activation(activation)(torch.nn.functional.linear(x, weight, bias)), random_output
+    ).backward(),
+    "triton": lambda weight, bias, activation, random_output: lambda x: MSELoss(reduction="sum")(
+        linear_layer(x, weight, bias, activation, None), random_output
+    ).backward(),
+}
+
+
+@set_seed()
+@pytest.mark.parametrize("contiguous", [True, False], ids=["contiguous", "non-contiguous"])
+@pytest.mark.parametrize("activation", ["", "tanh", "gelu", "relu"], ids=["no_activation", "tanh", "gelu", "relu"])
+@pytest.mark.parametrize(
+    "shape",
+    [(1, 8, 8, 8)] + [(bs, M, 768, 768) for bs in [1, 16] for M in [8, 16, 128, 256, 512]],
+    ids=lambda s: "x".join(map(str, s)),
+)
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"])
+@pytest.mark.parametrize("cuda_graphs", [False], ids=["no_cuda_graphs"])
+@pytest.mark.parametrize("implementation", ["triton", "pytorch"])
+def test_benchmark_linear_backward(
+    benchmark,
+    implementation: str,
+    cuda_graphs: bool,
+    shape: Tuple[int, int, int, int],
+    dtype: torch.dtype,
+    activation: str,
+    contiguous: bool,
+):
+    batch, M, N, K = shape
+    factory_kwargs = {"device": "cuda", "dtype": torch.float32, "requires_grad": True}
+    # build x transposed so the non-contiguous case is real and `.contiguous()` below has an effect
+    x = torch.randn((batch, K, M), **factory_kwargs)
+    x = x.mT
+    if contiguous:
+        x = x.contiguous()
+    else:
+        assert not x.is_contiguous()
+    x = x.to(dtype=dtype)
+    x.retain_grad()  # force saving grad on this non-leaf tensor
+
+    layer_weight = torch.randn((N, K), **factory_kwargs)
+    layer_weight = layer_weight.to(dtype=dtype)
+    layer_bias = torch.randn((N,), **factory_kwargs)
+    layer_bias = layer_bias.to(dtype=dtype)
+
+    x_triton = torch.clone(x)
+    x_triton = x_triton.to(dtype=dtype)
+    assert x_triton.is_contiguous() == x.is_contiguous()
+    x_triton.retain_grad()
+
+    pytorch_layer_activation = get_pytorch_activation(activation)
+    pytorch_fwd_output = pytorch_layer_activation(torch.nn.functional.linear(x, layer_weight, layer_bias))
+    random_output = torch.randn(pytorch_fwd_output.shape, **factory_kwargs)
+    random_output = random_output.to(dtype=dtype)
+    loss = MSELoss(reduction="sum")
+
+    with autocast(dtype=dtype):
+        loss(pytorch_fwd_output, random_output).backward()
+        fn = backward_implementations[implementation](layer_weight, layer_bias, activation, random_output)
+        if cuda_graphs:
+            run = cuda_graphs_wrapper(model=fn, inputs=[x_triton])
+            # CUDA graphs wraps output in a tuple
+            fn = lambda tensor: run([tensor])[0]  # noqa: E731
+
+        _ = benchmark(fn, x_triton)
+        assert_all_close(x.grad, x_triton.grad, rtol=1e-1, atol=1e-1)
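
Note (not part of the patch): the closed form used by `gelu_grad` above, cdf + x * pdf, can be cross-checked against PyTorch autograd on the CPU. A minimal sketch follows; the helper name `gelu_grad_ref` is ours, everything else is stock PyTorch.

import math

import torch


def gelu_grad_ref(x: torch.Tensor) -> torch.Tensor:
    # same closed form as the Triton gelu_grad: d/dx [x * Phi(x)] = Phi(x) + x * phi(x)
    cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    pdf = torch.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return cdf + x * pdf


x = torch.randn(4096, dtype=torch.float64, requires_grad=True)
torch.nn.functional.gelu(x).sum().backward()  # exact (erf-based) gelu, same formula as the Triton `gelu`
assert torch.allclose(x.grad, gelu_grad_ref(x.detach()), atol=1e-6)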
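Note (not part of the patch): a minimal end-to-end sketch of how the new backward path is exercised, mirroring `test_benchmark_linear_backward`. It assumes a CUDA device, the `linear_layer(x, weight, bias, activation, act_inputs)` call signature used in the tests, and the import path implied by the file location; only the gradient with respect to the input is produced, since backward returns None for the other arguments.

import torch

from kernl.implementations.linear_layer import linear_layer  # assumed import path

x = torch.randn((16, 128, 768), device="cuda", dtype=torch.float16, requires_grad=True)
weight = torch.randn((768, 768), device="cuda", dtype=torch.float16)
bias = torch.randn((768,), device="cuda", dtype=torch.float16)

out = linear_layer(x, weight, bias, "gelu", None)  # fused linear + activation forward (kernel_fma)
out.sum().backward()                               # LinearLayer.backward launches kernel_bwd

# plain PyTorch reference for the input gradient
x_ref = x.detach().clone().requires_grad_(True)
torch.nn.functional.gelu(torch.nn.functional.linear(x_ref, weight, bias)).sum().backward()
torch.testing.assert_close(x.grad, x_ref.grad, rtol=1e-1, atol=1e-1)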