4 changes: 4 additions & 0 deletions mkdocs.yml
@@ -160,6 +160,9 @@ plugins:
python:
options:
show_source: false
import:
- url: https://pytorch.org/docs/1.13/objects.inv
enable_inventory: false
watch:
- src/kernl

@@ -174,6 +177,7 @@ nav:
- Code Reference:
- Model Optimization: reference/model_optimization.md
- Graph Pattern Matching: reference/extended_matcher.md
- Autocast: reference/autocast.md
- Contribution guide:
- Contributing: contribution-guide/contributing.md
- Code of conduct: contribution-guide/code-of-conduct.md
158 changes: 158 additions & 0 deletions src/kernl/autocast.py
@@ -0,0 +1,158 @@
import functools
from typing import Any

import torch

from kernl.utils.cast import _cast


_is_autocast_enabled: bool = True
_autocast_dtype: torch.dtype = torch.float16
_is_torch_autocast_dtype_used: bool = True


def enable_autocast(is_autocast_enabled: bool):
"""Change autocast state for kernl.
Args:
is_autocast_enabled: True to enable autocast, False to disable
"""
global _is_autocast_enabled
_is_autocast_enabled = is_autocast_enabled


def is_autocast_enabled() -> bool:
"""Check is autocast enabled.
Returns:
wether autocast is enabled
"""
return _is_autocast_enabled


def set_autocast_dtype(dtype: torch.dtype) -> None:
"""Change autocast target dtype.
Args:
        dtype: the target dtype to autocast to
"""
global _autocast_dtype
_autocast_dtype = dtype


def get_autocast_dtype() -> torch.dtype:
"""Get autocast target dtype.
Returns:
        autocast target dtype
"""
return _autocast_dtype


def use_torch_autocast_dtype(use_torch_autocast_dtype: bool) -> None:
""" Check i
Args:
use_torch_autocast_dtype: wether torch autocast dtype is used
"""
global _is_torch_autocast_dtype_used
_is_torch_autocast_dtype_used = use_torch_autocast_dtype


def is_torch_autocast_dtype_used() -> bool:
""" Check if torch autocast dtype is used in autocast
Returns:
wether torch autocast dtype is used
"""
return _is_torch_autocast_dtype_used


class autocast(object):
"""Serve as context managers or decorators that allow fused kernels to run in mixed precision.

In these regions, kernels run in an op-specific dtype chosen by autocast
to improve performance while maintaining accuracy.

When entering an autocast-enabled region, Tensors may be any type.
You should not call `half()` or `bfloat16()` on your model(s) or inputs when using autocasting.

`autocast` should wrap only the forward pass(es) of your network, including the loss
computation(s). Backward passes under autocast are not recommended.
Backward ops run in the same type that autocast used for corresponding forward ops.

Heavily inspired by [torch.autocast][].
Args:
enabled: Whether autocasting should be enabled in the region.
        dtype: The target dtype to autocast to, either torch.float16 or torch.bfloat16.
        use_torch_autocast_dtype: Whether to use the torch autocast dtype instead of `dtype`.
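
    Example (illustrative sketch; `model` and `inputs` are placeholder names):

        with autocast(dtype=torch.bfloat16):
            output = model(inputs)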
"""
def __init__(
self, enabled: bool = True, dtype: torch.dtype = torch.float16, use_torch_autocast_dtype: bool = False
):
self._enabled = enabled
self._dtype = dtype
self._use_autocast_dtype = use_torch_autocast_dtype

def __enter__(self):
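        # Save the current global autocast state so __exit__ can restore it.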
self._prev = is_autocast_enabled()
self._prev_dtype = get_autocast_dtype()
self._prev_use_torch_autocast_dtype = is_torch_autocast_dtype_used()
enable_autocast(self._enabled)
set_autocast_dtype(self._dtype)
use_torch_autocast_dtype(self._use_autocast_dtype)

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
enable_autocast(self._prev)
set_autocast_dtype(self._prev_dtype)
use_torch_autocast_dtype(self._prev_use_torch_autocast_dtype)


def custom_fwd(fwd=None, **kwargs):
"""
Helper decorator for ``forward`` methods of custom autograd functions (subclasses of
:class:`torch.autograd.Function`). See the :ref:`example page<amp-custom-examples>` for more detail.

Args:
cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
when ``forward`` runs in an autocast-enabled region, casts incoming
floating-point CUDA Tensors to the target dtype (non-floating-point Tensors are not affected),
then executes ``forward`` with autocast disabled.
If ``None``, ``forward``'s internal ops execute with the current autocast state.

.. note::
If the decorated ``forward`` is called outside an autocast-enabled region,
:func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
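
    Example (illustrative sketch, with ``MyOp`` as a placeholder name)::

        class MyOp(torch.autograd.Function):
            @staticmethod
            @custom_fwd(cast_inputs=torch.float16)
            def forward(ctx, x):
                return x * 2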
"""
if fwd is None:
if len(kwargs) == 0:
cast_inputs = None
else:
assert len(kwargs) == 1
cast_inputs = kwargs["cast_inputs"]

return functools.partial(custom_fwd, cast_inputs=cast_inputs)

if len(kwargs) == 0:
cast_inputs = None
else:
assert len(kwargs) == 1
cast_inputs = kwargs["cast_inputs"]

@functools.wraps(fwd)
def decorate_fwd(*args, **kwargs):
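        # Despite its name, `inputs` ends up holding the dtype to cast floating-point tensors to;
        # None means no casting is performed.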
inputs = cast_inputs
if inputs is None:
if is_autocast_enabled():
if is_torch_autocast_dtype_used():
if torch.is_autocast_enabled():
inputs = torch.get_autocast_gpu_dtype()
else:
inputs = get_autocast_dtype()
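        # No cast dtype was resolved: record the torch autocast state and call forward unchanged.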
if inputs is None:
args[0]._fwd_used_autocast = torch.is_autocast_enabled()
return fwd(*args, **kwargs)
else:
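            # A cast dtype was resolved (or cast_inputs was given): cast floating-point tensors
            # to it and run forward with torch autocast disabled to avoid re-casting inside.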
autocast_context = torch.is_autocast_enabled() or is_autocast_enabled()
args[0]._fwd_used_autocast = torch.is_autocast_enabled()
if autocast_context:
with torch.cuda.amp.autocast(enabled=False):
return fwd(*_cast(args, inputs), **_cast(kwargs, inputs))
else:
return fwd(*args, **kwargs)

return decorate_fwd
5 changes: 3 additions & 2 deletions src/kernl/implementations/attention.py
@@ -19,7 +19,8 @@
import triton
import triton.language as tl
from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd

from kernl.autocast import custom_fwd


# CREDITS: Initially inspired by the Triton tutorial
@@ -448,7 +449,7 @@ def _fwd_kernel(

class Attention(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
@custom_fwd()
def forward(
ctx: FunctionCtx,
q: torch.Tensor,
3 changes: 2 additions & 1 deletion src/kernl/implementations/attention_skinny.py
@@ -18,7 +18,8 @@
import triton
import triton.language as tl
from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd

from kernl.autocast import custom_fwd


def attention_split_1_reference(
5 changes: 3 additions & 2 deletions src/kernl/implementations/attention_vec_mat.py
@@ -4,7 +4,8 @@
import triton
import triton.language as tl
from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd

from kernl.autocast import custom_fwd


@triton.autotune(
Expand Down Expand Up @@ -152,7 +153,7 @@ def grid(args) -> Tuple[int, int, int]:

class AttentionVecMat(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
@custom_fwd()
def forward(
ctx: FunctionCtx,
q: torch.Tensor,
5 changes: 3 additions & 2 deletions src/kernl/implementations/layer_norm.py
@@ -18,9 +18,10 @@
import triton
import triton.language as tl
from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd
from triton import JITFunction

from kernl.autocast import custom_fwd


# CREDITS: Initially inspired by the Triton tutorial and xformers implementation

@@ -278,7 +279,7 @@ def _layer_norm_fwd_fused_multi_pass(

class LayerNorm(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
@custom_fwd()
def forward(
ctx: FunctionCtx,
x: torch.Tensor,
4 changes: 2 additions & 2 deletions src/kernl/implementations/linear_layer.py
@@ -22,9 +22,9 @@
import triton
import triton.language as tl
from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

from kernl.autocast import custom_fwd
from kernl.implementations import activation_func


@@ -205,7 +205,7 @@ def kernel_fma(

class LinearLayer(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
@custom_fwd()
def forward(
ctx: FunctionCtx,
x: torch.Tensor,
67 changes: 67 additions & 0 deletions test/test_autocast.py
@@ -0,0 +1,67 @@
import pytest
import torch

from kernl.autocast import autocast, custom_fwd, get_autocast_dtype, is_autocast_enabled, is_torch_autocast_dtype_used


def test_autocast():

with autocast():
assert is_autocast_enabled() is True

with autocast(enabled=False):
assert is_autocast_enabled() is False
assert is_autocast_enabled() is True

with autocast(dtype=torch.float32):
assert get_autocast_dtype() is torch.float32
assert get_autocast_dtype() is torch.float16

with autocast(use_torch_autocast_dtype=False):
assert is_torch_autocast_dtype_used() is False
assert is_torch_autocast_dtype_used() is True


@pytest.mark.parametrize(
"""
torch_autocast_enabled, kernl_autocast_enabled, torch_autocast_dtype,
kernl_autocast_dtype, use_torch_autocast_dtype, cast_inputs, expected_dtype""",
[
# we first check if @custom_fwd behaves the same as torch.cuda.amp.custom_fwd
# when we're not using kernl autocast
(False, False, None, None, False, None, torch.float32),
(False, False, None, None, False, torch.bfloat16, torch.float32),
(True, False, torch.float16, None, False, None, torch.float32),
(True, False, torch.float16, None, False, torch.bfloat16, torch.bfloat16),
# we check @custom_fwd without torch autocast
(False, True, None, torch.float16, False, None, torch.float16),
(False, True, None, torch.float16, False, torch.bfloat16, torch.bfloat16),
(False, True, None, torch.float16, True, None, torch.float32),
(False, True, None, torch.float16, True, torch.bfloat16, torch.bfloat16),
# we check @custom_fwd within torch autocast context and kernl autocast context
(True, True, torch.float16, torch.bfloat16, False, None, torch.bfloat16),
(True, True, torch.float16, torch.float16, False, torch.bfloat16, torch.bfloat16),
(True, True, torch.float16, torch.bfloat16, False, None, torch.bfloat16),
(True, True, torch.float16, torch.bfloat16, True, None, torch.float16),
(True, True, torch.float16, torch.float16, True, torch.bfloat16, torch.bfloat16),
],
)
def test_custom_fwd(
torch_autocast_enabled,
kernl_autocast_enabled,
torch_autocast_dtype,
kernl_autocast_dtype,
use_torch_autocast_dtype,
cast_inputs,
expected_dtype,
):
@custom_fwd(cast_inputs=cast_inputs)
def fwd(inputs: torch.Tensor):
assert inputs.dtype == expected_dtype

with torch.cuda.amp.autocast(enabled=torch_autocast_enabled, dtype=torch_autocast_dtype), autocast(
enabled=kernl_autocast_enabled, dtype=kernl_autocast_dtype, use_torch_autocast_dtype=use_torch_autocast_dtype
):
        inputs = torch.empty((100,), dtype=torch.float32, device="cuda")
        fwd(inputs)
11 changes: 7 additions & 4 deletions test/test_layer_norm.py
@@ -17,7 +17,6 @@
import torch

from conftest import assert_all_close, set_seed

from kernl.implementations.layer_norm import (
_layer_norm_fwd_fused_multi_pass,
_layer_norm_fwd_fused_single_pass,
@@ -47,16 +46,20 @@
@set_seed()
@pytest.mark.parametrize("shape", [128, 512, 1024, 2048, 4096], ids=lambda x: f"shape={x}x{x}")
@pytest.mark.parametrize("cuda_graphs", [True, False], ids=["cuda_graphs", "no_cuda_graphs"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=["fp16", "bfloat16", "fp32"])
@pytest.mark.parametrize("implementation", implementations_layer_norm.keys())
def test_benchmark_layer_norm(benchmark, shape: int, dtype, cuda_graphs: bool, implementation: str):
if dtype == torch.bfloat16 and implementation == "pytorch_naive":
pytest.skip("unsupported configuration")
M = N = shape
eps = 1e-5
factory_kwargs = {"device": "cuda", "dtype": torch.float32, "requires_grad": False}
factory_kwargs = {"device": "cuda", "dtype": torch.bfloat16 if dtype == torch.bfloat16 else torch.float32, "requires_grad": False}
layer_weight = torch.rand((N,), **factory_kwargs)
layer_bias = torch.randn_like(layer_weight)
x = -20 + 0.5 * torch.randn((M, N), **factory_kwargs)
expected = torch.nn.functional.layer_norm(x, layer_weight.shape, layer_weight, layer_bias, eps)
if dtype == torch.bfloat16:
expected = expected.float()

# tensors casting
layer_weight = layer_weight.to(dtype)
@@ -82,7 +85,7 @@ def test_benchmark_layer_norm(benchmark, shape: int, dtype, cuda_graphs: bool, i


@pytest.mark.parametrize("cuda_graphs", [True, False], ids=["cuda_graphs", "no_cuda_graphs"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=["fp16", "bf16", "fp32"])
@pytest.mark.parametrize("shape", [128, 512, 1024, 2048, 4096], ids=lambda x: f"shape={x}x{x}")
@pytest.mark.parametrize("implementation", implementations_rms_norm.keys())
def test_benchmark_rms_norm(benchmark, shape: int, dtype, cuda_graphs: bool, implementation: str):