Draft
Changes from all commits (28 commits)
2797d13
feat: add linear layer backward + activations grad
ayoub-louati Jan 5, 2023
d332ade
feat: add linear backwar wrapper +renaming + code formatting
ayoub-louati Jan 5, 2023
5a6a487
feat: fix code formatting
ayoub-louati Jan 5, 2023
5a4eba6
feat: add fast gelu grad
ayoub-louati Jan 5, 2023
1587cff
feat: fix linear layer (forward and backward)
ayoub-louati Jan 10, 2023
1cbfadf
feat: fix backward kernel call
ayoub-louati Jan 10, 2023
c4fb771
feat: add backward benchmark
ayoub-louati Jan 11, 2023
ed5dc33
feat: define linear_layer class with forward and backward functions
ayoub-louati Jan 11, 2023
738c70f
feat: use FunctionCtx to pass tensors from forward to backward function
ayoub-louati Jan 11, 2023
73cc6e2
feat: fix forward benchmark
ayoub-louati Jan 11, 2023
74e3643
feat: fix benchmark backward
ayoub-louati Jan 13, 2023
e204789
Merge branch 'main' into feat/add_linear_layer_train_support
ayoub-louati Jan 13, 2023
c17002f
feat: fix benchmark backward
ayoub-louati Jan 13, 2023
cad7341
Merge branch 'main' into feat/add_linear_layer_train_support
ayoub-louati Jan 13, 2023
ceb40be
feat: add backward benchmark
ayoub-louati Jan 17, 2023
b29b9ed
feat: add missing tests
ayoub-louati Jan 17, 2023
0103ec7
feat: remove debugging print in backward function + fix fp16 tests wi…
ayoub-louati Jan 18, 2023
ccc7a5f
feat: fix relu grad + fix benchmark tests
ayoub-louati Jan 20, 2023
b646a9a
feat: fix gelu grad function
ayoub-louati Jan 23, 2023
d162d27
feat: simplify activation function
ayoub-louati Jan 23, 2023
2d7e501
feat: fix stride -> fix unit tests
ayoub-louati Jan 23, 2023
2d46917
feat: fix tanh grad compilation
ayoub-louati Jan 23, 2023
b7228e9
feat: remove unnecessary modifications
ayoub-louati Jan 23, 2023
620be3a
feat: remove unused import
ayoub-louati Jan 24, 2023
0e8bbaa
Merge branch 'main' into feat/add_linear_layer_train_support
ayoub-louati Jan 27, 2023
5273fab
feat: pdate linear layer forward and backward implementations
ayoub-louati Jan 27, 2023
be678ec
Merge branch 'main' into feat/add_linear_layer_train_support
ayoub-louati Feb 1, 2023
1cc4387
Merge branch 'main' into feat/add_linear_layer_train_support
ayoub-louati Feb 10, 2023
31 changes: 31 additions & 0 deletions src/kernl/implementations/activation_func.py
@@ -23,6 +23,7 @@

sqrt2pi = math.sqrt(2.0 / math.pi)
sqrt2 = math.sqrt(2.0)
gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)


@triton.jit
@@ -31,19 +32,49 @@ def tanh(x):
return tl.libdevice.tanh(x)


@triton.jit
def tanh_grad(x):
"""Tanh derivative function"""
tanh_x = tanh(x)
return 1 - tanh_x * tanh_x


@triton.jit
def relu(x):
"""Relu activation function"""
return tl.maximum(0, x)


@triton.jit
def relu_grad(x):
"""Relu derivative function"""
return tl.maximum(0, 1)


@triton.jit
def fast_gelu(x):
"""Fast approximation of the gelu function. May slightly decrease accuracy."""
return 0.5 * x * (1 + tanh(sqrt2pi * (x + 0.044715 * x * x * x)))


@triton.jit
def fast_gelu_grad(x):
"""Derivative of fast approximation of the gelu function."""
# CREDITS: Fast implementation proposed in
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)


@triton.jit
def gelu(x):
"""Gaussian Error Linear Unit (GELU)"""
return x * 0.5 * (1.0 + tl.libdevice.erf(x / sqrt2))


@triton.jit
def gelu_grad(x):
"""Derivative of Gaussian Error Linear Unit (GELU)"""
cdf = 0.5 * (1.0 + tl.libdevice.erf(x / sqrt2))
pdf = tl.exp(-0.5 * x * x) * gaussian_pdf_normalization
return cdf + x * pdf
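
These closed-form derivatives can be sanity-checked against torch.autograd outside Triton. A minimal sketch for fast_gelu, assuming plain-PyTorch mirrors of the functions above (the *_ref helpers are illustrative and not part of this diff):

import math

import torch

sqrt2pi = math.sqrt(2.0 / math.pi)


def fast_gelu_ref(x: torch.Tensor) -> torch.Tensor:
    # plain-PyTorch mirror of the Triton fast_gelu above
    return 0.5 * x * (1 + torch.tanh(sqrt2pi * (x + 0.044715 * x * x * x)))


def fast_gelu_grad_ref(x: torch.Tensor) -> torch.Tensor:
    # plain-PyTorch mirror of the Triton fast_gelu_grad above (Megatron-LM formula)
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)


x = torch.randn(4096, dtype=torch.float64, requires_grad=True)
(autograd_grad,) = torch.autograd.grad(fast_gelu_ref(x).sum(), x)
assert torch.allclose(autograd_grad, fast_gelu_grad_ref(x.detach()), atol=1e-6)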
190 changes: 188 additions & 2 deletions src/kernl/implementations/linear_layer.py
@@ -16,13 +16,13 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
from typing import Any, Optional

import torch
import triton
import triton.language as tl
from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd
from torch.cuda.amp import custom_bwd, custom_fwd
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

from kernl.implementations import activation_func
@@ -203,6 +203,127 @@ def kernel_fma(
tl.store(C, acc, mask=c_ptr_mask)


@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
]
+ get_configs_io_bound(),
key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
)
@triton.heuristics(
{
"K_LOAD_MASK_NEEDED": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
}
)
@triton.jit
def kernel_bwd(
C, # Pointers to matrices
ACT_INPUT,
A,
B,
# Matrix dimensions
M,
N,
K,
CACHE_KEY_M,
CACHE_KEY_N,
CACHE_KEY_K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. a_m_stride is how much to increase the A
# pointer to get the element one row down (A has M rows)
output_m_stride,
output_n_stride,
a_m_stride,
a_k_stride,
w_n_stride,
w_k_stride,
# Meta-parameters
BLOCK_M: tl.constexpr,
GROUP_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
# split k is not used (not performant with an activation); kept because early_config_prune expects it
SPLIT_K: tl.constexpr,
K_LOAD_MASK_NEEDED: tl.constexpr,
ACTIVATION: tl.constexpr,
):
program_idx = tl.program_id(axis=0)

grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_idx = program_idx // width
group_size = min(grid_m - group_idx * GROUP_M, GROUP_M)
block_m_idx = group_idx * GROUP_M + (program_idx % group_size)
block_n_idx = (program_idx % width) // (group_size)

# now compute the block that each program will go through
# m_offs_untagged (resp. n_offs_untagged) denote the range of row (resp. column)
# indices of C covered by this program
m_offs_untagged = block_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
n_offs_untagged = block_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)

# trick to avoid masking on M and N axis
m_offs = tl.max_contiguous(tl.multiple_of(m_offs_untagged % M, BLOCK_M), BLOCK_M)
n_offs = tl.max_contiguous(tl.multiple_of(n_offs_untagged % N, BLOCK_N), BLOCK_N)
k_range_offs = tl.arange(0, BLOCK_K)

A = A + (m_offs[:, None] * a_m_stride + k_range_offs[None, :] * a_k_stride)
B = B + (k_range_offs[:, None] * w_k_stride + n_offs[None, :] * w_n_stride)

acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

for k in range(K, 0, -BLOCK_K):
if K_LOAD_MASK_NEEDED:
a = tl.load(A, mask=k_range_offs[None, :] < k, other=0.0)
b = tl.load(B, mask=k_range_offs[:, None] < k, other=0.0)
else:
a = tl.load(A)
b = tl.load(B)
acc += tl.dot(a, b)

A += BLOCK_K * a_k_stride
B += BLOCK_K * w_k_stride

# optional: fused activation gradient (applied while the data is still on-chip)
if ACTIVATION != "":
act_in_ptrs = ACT_INPUT + m_offs[:, None] * output_m_stride + n_offs[None, :] * output_n_stride
act_input = tl.load(act_in_ptrs).to(acc.dtype)
if ACTIVATION == "tanh":
acc *= activation_func.tanh_grad(act_input)
if ACTIVATION == "gelu":
acc *= activation_func.gelu_grad(act_input)
if ACTIVATION == "fast_gelu":
acc *= activation_func.fast_gelu_grad(act_input)
if ACTIVATION == "relu":
acc *= activation_func.relu_grad(act_input)

# write back result
C = C + m_offs_untagged[:, None] * output_m_stride + n_offs_untagged[None, :] * output_n_stride
mask = (m_offs_untagged < M)[:, None] & (n_offs_untagged < N)[None, :]
tl.store(C, acc, mask=mask)
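
In plain PyTorch terms, kernel_bwd computes C = A @ B and, when an activation is fused, multiplies the result elementwise by the activation derivative evaluated at ACT_INPUT (same shape as C). A reference sketch for checking the kernel (the function name is illustrative, not part of this diff):

def kernel_bwd_reference(a, b, act_input=None, activation_grad=None):
    # a: (M, K) grad_output, b: (K, N) weight as addressed by the kernel strides
    out = a.float() @ b.float()
    if activation_grad is not None and act_input is not None:
        # fused path: scale by the activation derivative at act_input, shape (M, N)
        out = out * activation_grad(act_input.float())
    return out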


class LinearLayer(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
@@ -236,6 +357,7 @@ def forward(
assert bias is None or bias.shape[0] == weight.shape[0], "Incompatible dimensions in between weight and bias"
assert weight.is_contiguous()

ctx.activation = activation
M, K = x_.shape
N, K = weight.shape

@@ -274,6 +396,70 @@ def forward(
ctx.save_for_backward(weight, bias, x)
return outputs

@staticmethod
@custom_bwd
def backward(
ctx: FunctionCtx,
*grad_outputs: Any,
) -> torch.Tensor:
"""
Compute e = activation(grad_output @ weight + bias).
This wrapper kicks the `kernel_fwd` Triton kernel
:param ctx: context for autograd
:param grad_outputs: input tensor
:return: result tensor
"""
weight, bias, act_inputs = ctx.saved_tensors
grad_outputs = grad_outputs[0]
batch_shape, n = grad_outputs.shape[:-1], grad_outputs.shape[-1]
batch_dim = batch_shape.numel()
grad_output_reshaped = grad_outputs.reshape(batch_dim, n)

if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:
grad_output_reshaped = grad_output_reshaped.contiguous()
if weight.stride(0) > 1 and weight.stride(1) > 1:
weight = weight.contiguous()

assert (
grad_outputs.dtype == weight.dtype
), f"grad_output and weight must have the same dtype, got {grad_outputs.dtype} and {weight.dtype}"
assert (
grad_output_reshaped.shape[1] == weight.shape[0]
), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
if ctx.activation != "id":
assert act_inputs is not None, f"act_input is required for activation {ctx.activation}"

# M, N, K in bwd are different from M, N, K in fwd
M, K = grad_output_reshaped.shape
K, N = weight.shape

grad_input = torch.empty((M, N), device=grad_outputs.device, dtype=grad_outputs.dtype)

# 1D launch kernel where each block gets its own program.
grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa

kernel_bwd[grid](
grad_input,
act_inputs,
grad_output_reshaped,
weight, # data ptrs
M, # shapes
N,
K,
M // 32, # key for triton cache (limit number of compilations)
N // 32,
K // 32,
output_m_stride=grad_input.stride(0), # strides
output_n_stride=grad_input.stride(1),
a_m_stride=grad_output_reshaped.stride(0),
a_k_stride=grad_output_reshaped.stride(1),
w_n_stride=weight.stride(1),
w_k_stride=weight.stride(0),
ACTIVATION=ctx.activation, # optional fused activation
GROUP_M=8, # speed optimization: group the programs
)
return grad_input.reshape(*batch_shape, grad_input.shape[-1]), None, None, None, None


def linear_layer(
x: torch.Tensor,
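With the backward pass in place, the layer can be used in training mode. A minimal usage sketch following the call signature exercised in the tests below (shapes, the "gelu" choice and the import path, inferred from the file location, are illustrative):

import torch

from kernl.implementations.linear_layer import linear_layer

x = torch.randn((1, 128, 768), device="cuda", dtype=torch.float16, requires_grad=True)
weight = torch.randn((768, 768), device="cuda", dtype=torch.float16)
bias = torch.randn((768,), device="cuda", dtype=torch.float16)

# forward dispatches to the existing forward kernel, backward to kernel_bwd above
out = linear_layer(x, weight, bias, "gelu", None)
out.sum().backward()
assert x.grad.shape == x.shape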
78 changes: 75 additions & 3 deletions test/test_linear_layer.py
@@ -17,6 +17,8 @@

import pytest
import torch
from torch.cuda.amp import autocast
from torch.nn import MSELoss

from conftest import assert_all_close, set_seed

@@ -37,7 +39,7 @@ def get_pytorch_activation(activation: str) -> Callable:
raise ValueError(f"Unknown activation: {activation}")


implementations = {
forward_implementations = {
"pytorch": lambda weight, bias, activation: lambda x: get_pytorch_activation(activation)(
torch.nn.functional.linear(x, weight, bias)
),
@@ -57,7 +59,7 @@ def get_pytorch_activation(activation: str) -> Callable:
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"])
@pytest.mark.parametrize("cuda_graphs", [False, True], ids=["no_cuda_graphs", "cuda_graphs"])
@pytest.mark.parametrize("implementation", ["triton", "pytorch"])
def test_benchmark(
def test_benchmark_linear_forward(
benchmark,
implementation: str,
cuda_graphs: bool,
@@ -88,7 +90,7 @@ def test_benchmark(
layer_bias = layer_bias.to(dtype=dtype)
x = x.to(dtype=dtype)

fn = implementations[implementation](layer_weight, layer_bias, activation)
fn = forward_implementations[implementation](layer_weight, layer_bias, activation)
if cuda_graphs:
run = cuda_graphs_wrapper(model=fn, inputs=[x])
# CUDA graphs wraps output in a tuple
@@ -97,3 +99,73 @@ def test_benchmark(
value = benchmark(fn, x)

assert_all_close(expected, value.float(), rtol=1e-1, atol=1e-1)


backward_implementations = {
"pytorch": lambda weight, bias, activation, random_output: lambda x: MSELoss(reduction="sum")(
get_pytorch_activation(activation)(torch.nn.functional.linear(x, weight, bias)), random_output
).backward(),
"triton": lambda weight, bias, activation, random_output: lambda x: MSELoss(reduction="sum")(
linear_layer(x, weight, bias, activation, None), random_output
).backward(),
}


@set_seed()
@pytest.mark.parametrize("contiguous", [True, False], ids=["contiguous", "non-contiguous"])
@pytest.mark.parametrize("activation", ["", "tanh", "gelu", "relu"], ids=["no_activation", "tanh", "gelu", "relu"])
@pytest.mark.parametrize(
"shape",
[(1, 8, 8, 8)] + [(bs, M, 768, 768) for bs in [1, 16] for M in [8, 16, 128, 256, 512]],
ids=lambda s: "x".join(map(str, s)),
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"])
@pytest.mark.parametrize("cuda_graphs", [False], ids=["no_cuda_graphs"])
@pytest.mark.parametrize("implementation", ["triton", "pytorch"])
def test_benchmark_linear_backward(
benchmark,
implementation: str,
cuda_graphs: bool,
shape: Tuple[int, int, int, int],
dtype: torch.dtype,
activation: str,
contiguous: bool,
):
batch, M, N, K = shape
factory_kwargs = {"device": "cuda", "dtype": torch.float32, "requires_grad": True}
# order of dimensions is wrong so we force contiguous call
x = torch.randn((batch, K, M), **factory_kwargs)
x = x.mT
if contiguous:
x = x.contiguous()
else:
assert not x.is_contiguous()
x = x.to(dtype=dtype)
x.retain_grad() # force saving grad

layer_weight = torch.randn((N, K), **factory_kwargs)
layer_weight = layer_weight.to(dtype=dtype)
layer_bias = torch.randn((N,), **factory_kwargs)  # bias matches the out_features dimension
layer_bias = layer_bias.to(dtype=dtype)

x_triton = torch.clone(x)
x_triton = x_triton.to(dtype=dtype)
assert x_triton.is_contiguous() == x.is_contiguous()
x_triton.retain_grad()

pytorch_layer_activation = get_pytorch_activation(activation)
pytorch_fwd_output = pytorch_layer_activation(torch.nn.functional.linear(x, layer_weight, layer_bias))
random_output = torch.randn(pytorch_fwd_output.shape, **factory_kwargs)
random_output = random_output.to(dtype=dtype)
loss = MSELoss()

with autocast(dtype=dtype):
loss(pytorch_fwd_output, random_output).backward()
fn = backward_implementations[implementation](layer_weight, layer_bias, activation, random_output)
if cuda_graphs:
run = cuda_graphs_wrapper(model=fn, inputs=[x_triton])
# CUDA graphs wraps output in a tuple
fn = lambda tensor: run([tensor])[0] # noqa: E731

_ = benchmark(fn, x_triton)
assert_all_close(x.grad, x_triton.grad, rtol=1e-1, atol=1e-1)
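
For a quick correctness check outside the benchmark fixture, the same comparison reduces to a few lines (a sketch reusing the helpers above; shapes, activation and tolerances are illustrative):

x_ref = torch.randn((1, 128, 768), device="cuda", dtype=torch.float16, requires_grad=True)
x_tri = x_ref.detach().clone().requires_grad_(True)
weight = torch.randn((768, 768), device="cuda", dtype=torch.float16)
bias = torch.randn((768,), device="cuda", dtype=torch.float16)

get_pytorch_activation("gelu")(torch.nn.functional.linear(x_ref, weight, bias)).sum().backward()
linear_layer(x_tri, weight, bias, "gelu", None).sum().backward()
assert_all_close(x_ref.grad.float(), x_tri.grad.float(), rtol=1e-1, atol=1e-1)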