diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py
index 68ad9144..fb03fce4 100644
--- a/src/kernl/implementations/linear_layer.py
+++ b/src/kernl/implementations/linear_layer.py
@@ -89,7 +89,7 @@ def get_configs_io_bound():
         }
     )
 @triton.jit
-def kernel_fma(
+def kernel_linear(
     C,  # Pointers to matrices
     ACT_INPUTS,
     A,
@@ -124,19 +124,56 @@ def kernel_fma(
     HAS_BIAS: tl.constexpr,
     SHOULD_SAVE_ACT_INPUTS: tl.constexpr,
     ACTIVATION: tl.constexpr,
+    # quantization scalers
+    ALPHA_SCALER: tl.constexpr,
+    BETA_SCALER: tl.constexpr,
+    ACC_TYPE: tl.constexpr,
 ):
     """
-    Kernel for computing Out = activation(A x W + C)
+    Matrix multiplication kernel with fused activation and bias.
+    Out = activation(alpha * A x W + beta * C)

     - Input has shape (M, K)
     - Weight has shape (K, N)
-    - Bias has shape (N,)
+    - Bias has shape (N,) -> the bias is added to each row of the matmul output
     - Output has shape (M, N)
     - ActInputs (optional) has shape (M, N)

     'ActInputs' optionally saves the A x W + C intermediate for backward computations

     This kernel will consolidate over K
+
+    :param C: Output matrix
+    :param ACT_INPUTS: (Optional) tensor to save the activation inputs (for backward)
+    :param A: Input matrix A (inputs)
+    :param B: Input matrix B (transposed weights)
+    :param bias: Bias vector
+    :param M: Number of rows in A and C
+    :param N: Number of columns in B and C
+    :param K: Number of columns in A and rows in B
+    :param CACHE_KEY_M: Cache key for M
+    :param CACHE_KEY_N: Cache key for N
+    :param CACHE_KEY_K: Cache key for K
+    :param output_m_stride: Stride of output matrix C along M
+    :param output_n_stride: Stride of output matrix C along N
+    :param act_inputs_m_stride: Stride of activation inputs matrix ACT_INPUTS along M
+    :param act_inputs_n_stride: Stride of activation inputs matrix ACT_INPUTS along N
+    :param a_m_stride: Stride of input matrix A along M
+    :param a_k_stride: Stride of input matrix A along K
+    :param b_n_stride: Stride of input matrix B along N
+    :param b_k_stride: Stride of input matrix B along K
+    :param BLOCK_M: Block size in the M dimension
+    :param GROUP_M: Number of M blocks grouped together (L2 cache reuse optimization)
+    :param BLOCK_N: Block size in the N dimension
+    :param BLOCK_K: Block size in the K dimension
+    :param SPLIT_K: Number of blocks splitting the K dimension (split-K)
+    :param K_LOAD_MASK_NEEDED: Whether or not to use a mask when loading from B
+    :param HAS_BIAS: Whether or not to add a bias to the result
+    :param SHOULD_SAVE_ACT_INPUTS: Whether or not to save the activation inputs
+    :param ACTIVATION: Activation function to apply
+    :param ACC_TYPE: Accumulation type
+
+    :return: None
     """

     program_idx = tl.program_id(axis=0)
@@ -164,11 +201,7 @@ def kernel_fma(
     A = A + (m_offs[:, None] * a_m_stride + k_range_offs[None, :] * a_k_stride)
     B = B + (k_range_offs[:, None] * b_k_stride + n_offs[None, :] * b_n_stride)

-    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
-    if HAS_BIAS:
-        bias = tl.load(bias + n_offs, mask=n_offs < N, other=0.0).to(tl.float32)
-        acc += bias[None, :]
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)

     for k in range(K, 0, -BLOCK_K):
         if K_LOAD_MASK_NEEDED:
@@ -182,6 +215,15 @@ def kernel_fma(
         A += BLOCK_K * a_k_stride
         B += BLOCK_K * b_k_stride

+    if ALPHA_SCALER != 1.0:
+        acc *= ALPHA_SCALER
+
+    if HAS_BIAS:
+        bias = tl.load(bias + n_offs, mask=n_offs < N, other=0.0)  # .to(ACC_TYPE)  # TODO: fix when Triton is updated
+        if BETA_SCALER != 1.0:
+            bias *= BETA_SCALER
+        acc += bias[None, :]
+
     # optional: save the activation inputs
     if SHOULD_SAVE_ACT_INPUTS:
         act_in_ptrs = ACT_INPUTS + m_offs[:, None] * act_inputs_m_stride + n_offs[None, :] * act_inputs_n_stride
@@ -212,7 +254,10 @@ def forward(
         weight: torch.Tensor,
         bias: Optional[torch.Tensor],
         activation: str,
+        alpha_scaler: float,
+        beta_scaler: float,
         act_inputs: Optional[torch.Tensor],
+        output: Optional[torch.Tensor],
     ) -> torch.Tensor:
         """
         Compute e = activation(x @ weight + bias).
@@ -221,16 +266,19 @@ def forward(
         :param x: input tensor
         :param weight: weight matrix
         :param bias: an optional bias tensor
-        :param activation: Activation name. Needs to be a Triton kernel.
+        :param activation: Activation name (relu, tanh, gelu, fast_gelu)
         :param act_inputs: an optional tensor to save the activation inputs (for backward)
+        :param alpha_scaler: alpha scaler (to be applied on the matmul output)
+        :param beta_scaler: beta scaler (to be applied on the bias)
+        :param output: an optional output tensor
         :return: result tensor
         """
         x_ = x if x.ndim == 2 else x.flatten(0, 1)

         assert x.dtype == weight.dtype, f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
-        if bias is not None:
-            assert x.dtype == bias.dtype, f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
-        assert x_.shape[1] == weight.shape[1], f"Incompatible dimensions: {x_.shape} - {weight.shape}"
+        # if bias is not None:
+        #     assert x.dtype == bias.dtype, f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
+        assert x_.shape[1] == weight.shape[1], f"Incompatible dimensions: {x_.shape} / {weight.shape}"
         assert bias is None or bias.is_contiguous()
         assert bias is None or bias.shape[0] == weight.shape[0], "Incompatible dimensions in between weight and bias"
@@ -239,13 +287,16 @@ def forward(
         M, K = x_.shape
         N, K = weight.shape

-        outputs = torch.empty((M, N), device=x.device, dtype=x.dtype)
-
+        if output is None:
+            output = torch.empty((M, N), device=x.device, dtype=x.dtype)
+        else:
+            output = output if output.ndim == 2 else output.flatten(0, 1)
+        acc_type = tl.float32 if output.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
         # 1D launch kernel where each block gets its own program.
         grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa

-        kernel_fma[grid](
-            outputs,
+        kernel_linear[grid](
+            output,
             act_inputs,
             x_,
             weight,  # data ptrs
@@ -256,8 +307,8 @@ def forward(
             M // 32,  # key for triton cache (limit number of compilations)
             N // 32,
             K // 32,
-            output_m_stride=outputs.stride(0),  # strides
-            output_n_stride=outputs.stride(1),
+            output_m_stride=output.stride(0),  # strides
+            output_n_stride=output.stride(1),
             act_inputs_m_stride=act_inputs.stride(0) if act_inputs is not None else 0,
             act_inputs_n_stride=act_inputs.stride(1) if act_inputs is not None else 0,
             a_m_stride=x_.stride(0),
@@ -267,10 +318,13 @@ def forward(
             HAS_BIAS=bias is not None,  # optional fused bias
             SHOULD_SAVE_ACT_INPUTS=act_inputs is not None,  # optional save activation inputs
             ACTIVATION=activation if not None else x,  # optional fused activation
+            ALPHA_SCALER=alpha_scaler,  # optional alpha scaler (quantization)
+            BETA_SCALER=beta_scaler,  # optional beta scaler (quantization)
+            ACC_TYPE=acc_type,  # accumulator type
             GROUP_M=8,  # speed optimization: group the programs
         )

-        outputs = outputs if x.ndim == 2 else outputs.reshape(x.shape[0], -1, N)
+        outputs = output if x.ndim == 2 else output.reshape(x.shape[0], -1, N)
         ctx.save_for_backward(weight, bias, x)

         return outputs
@@ -281,5 +335,8 @@ def linear_layer(
     bias: Optional[torch.Tensor],
     activation="",
     act_inputs: Optional[torch.Tensor] = None,
+    alpha_scaler=1.0,
+    beta_scaler=1.0,
+    output: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    return LinearLayer.apply(x, weight, bias, activation, act_inputs)
+    return LinearLayer.apply(x, weight, bias, activation, alpha_scaler, beta_scaler, act_inputs, output)
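
Note: the sketch below is not part of the patch. It shows how the extended linear_layer signature can be driven directly for an int8 GEMM; the shapes, scaler values and CUDA device are illustrative assumptions mirroring the new test further down. alpha_scaler rescales the matmul accumulator, beta_scaler rescales the bias, and passing a preallocated int8 output selects the int32 accumulator path.

import torch
from kernl.implementations.linear_layer import linear_layer

# illustrative shapes: x is (M, K) int8, weight is (N, K) int8, bias is int32
x = torch.randint(-127, 127, (16, 64), dtype=torch.int8, device="cuda")
weight = torch.randint(-127, 127, (32, 64), dtype=torch.int8, device="cuda")
bias = torch.randint(-127, 127, (32,), dtype=torch.int32, device="cuda")
out = torch.empty((16, 32), dtype=torch.int8, device="cuda")  # int8 output -> ACC_TYPE becomes tl.int32

y = linear_layer(
    x=x,
    weight=weight,
    bias=bias,
    activation="",
    act_inputs=None,
    alpha_scaler=0.01,    # applied to the accumulator
    beta_scaler=0.0001,   # applied to the bias
    output=out,
)
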
diff --git a/src/kernl/implementations/linear_layer_quant.py b/src/kernl/implementations/linear_layer_quant.py
new file mode 100644
index 00000000..1219f60a
--- /dev/null
+++ b/src/kernl/implementations/linear_layer_quant.py
@@ -0,0 +1,200 @@
+# Copyright 2022 Lefebvre Sarrut
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# code inspired by the torch-int package
+# https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
+
+import torch
+
+from kernl.implementations.linear_layer import linear_layer
+
+
+@torch.no_grad()
+def quantize_per_tensor_absmax(t: torch.Tensor):
+    scale = t.abs().max() / 127
+    if not t.is_cuda:
+        # half rounding is not supported on CPU
+        t = t.float()
+    # use inplace operation to save memory
+    t.div_(scale).round_()
+    t_q = t.to(torch.int8)
+    return t_q, scale
+
+
+class W8A8B8O8Linear(torch.nn.Module):
+    # For qkv_proj
+    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+
+        self.register_buffer(
+            "weight",
+            torch.randint(-127, 127, (self.out_features, self.in_features), dtype=torch.int8, requires_grad=False),
+        )
+        self.register_buffer("bias", torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False))
+        self.register_buffer("a", torch.tensor(alpha))
+        self.register_buffer("b", torch.tensor(beta))
+
+    def to(self, *args, **kwargs):
+        super().to(*args, **kwargs)
+        self.weight = self.weight.to(*args, **kwargs)
+        self.bias = self.bias.to(*args, **kwargs)
+        return self
+
+    @torch.no_grad()
+    def forward(self, x):
+        x_shape = x.shape
+        x = x.view(-1, x_shape[-1])
+        y = torch.empty((x.shape[0], self.weight.shape[0]), device=x.device, dtype=torch.int8)
+        linear_layer(
+            x=x,
+            weight=self.weight,
+            bias=self.bias,
+            activation="",
+            act_inputs=None,
+            alpha_scaler=self.a.item(),
+            beta_scaler=self.b.item(),
+            output=y,
+        )
+        y = y.view(*x_shape[:-1], -1)
+        return y
+
+    @staticmethod
+    def from_float(module: torch.nn.Linear, input_scale, output_scale):
+        int8_module = W8A8B8O8Linear(module.in_features, module.out_features)
+        int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
+        int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias)
+        alpha = input_scale * weight_scale / output_scale
+        beta = bias_scale / output_scale
+        int8_module.weight = int8_weight
+        int8_module.bias = int8_bias
+        int8_module.a = alpha
+        int8_module.b = beta
+        return int8_module
+
+
+class W8A8B8O8LinearReLU(torch.nn.Module):
+    # For fc1
+    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+
+        self.register_buffer(
+            "weight",
+            torch.randint(-127, 127, (self.out_features, self.in_features), dtype=torch.int8, requires_grad=False),
+        )
+        self.register_buffer("bias", torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False))
+        self.register_buffer("a", torch.tensor(alpha))
+        self.register_buffer("b", torch.tensor(beta))
+
+    def to(self, *args, **kwargs):
+        super().to(*args, **kwargs)
+        self.weight = self.weight.to(*args, **kwargs)
+        self.bias = self.bias.to(*args, **kwargs)
+        return self
+
+    @torch.no_grad()
+    def forward(self, x):
+        x_shape = x.shape
+        x = x.view(-1, x_shape[-1])
+        y = torch.empty((x.shape[0], self.weight.shape[0]), device=x.device, dtype=torch.int8)
+
+        linear_layer(
+            x=x,
+            weight=self.weight,
+            bias=self.bias,
+            activation="relu",
+            act_inputs=None,
+            alpha_scaler=self.a.item(),
+            beta_scaler=self.b.item(),
+            output=y,
+        )
+        y = y.view(*x_shape[:-1], -1)
+        return y
+
+    @staticmethod
+    def from_float(module: torch.nn.Linear, input_scale, output_scale):
+        # TODO: add zero-point to prevent the bit waste
+        int8_module = W8A8B8O8LinearReLU(module.in_features, module.out_features)
+        int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
+        int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias)
+        alpha = input_scale * weight_scale / output_scale
+        beta = bias_scale / output_scale
+        int8_module.weight = int8_weight
+        int8_module.bias = int8_bias
+        int8_module.a = alpha
+        int8_module.b = beta
+        return int8_module
+
+
+class W8A8BFP32OFP32Linear(torch.nn.Module):
+    # For fc2 and out_proj
+    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+
+        self.register_buffer(
+            "weight",
+            torch.randint(-127, 127, (self.out_features, self.in_features), dtype=torch.int8, requires_grad=False),
+        )
+        self.register_buffer("bias", torch.zeros((1, self.out_features), dtype=torch.float32, requires_grad=False))
+        self.register_buffer("a", torch.tensor(alpha))
+
+    def _apply(self, fn):
+        # prevent the bias from being converted to half
+        super()._apply(fn)
+        self.bias = self.bias.to(torch.float32)
+        return self
+
+    def to(self, *args, **kwargs):
+        super().to(*args, **kwargs)
+        self.weight = self.weight.to(*args, **kwargs)
+        self.bias = self.bias.to(*args, **kwargs)
+        self.bias = self.bias.to(torch.float32)
+        return self
+
+    @torch.no_grad()
+    def forward(self, x):
+        x_shape = x.shape
+        x = x.view(-1, x_shape[-1])
+        self.bias = self.bias.to(torch.float32)
+        y = torch.empty((x.shape[0], self.weight.shape[0]), device=x.device, dtype=torch.float32)
+
+        linear_layer(
+            x=x,
+            weight=self.weight,
+            bias=self.bias,
+            activation="",
+            act_inputs=None,
+            alpha_scaler=self.a.item(),
+            beta_scaler=1.0,
+            output=y,
+        )
+        y = y.view(*x_shape[:-1], -1)
+        return y
+
+    @staticmethod
+    def from_float(module: torch.nn.Linear, input_scale):
+        int8_module = W8A8BFP32OFP32Linear(module.in_features, module.out_features)
+        int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
+        alpha = input_scale * weight_scale
+        int8_module.weight = int8_weight
+        int8_module.bias = module.bias.to(torch.float32)
+        int8_module.a = alpha
+        int8_module.input_scale = input_scale
+        int8_module.weight_scale = weight_scale
+        return int8_module
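
Note: a minimal usage sketch for the new quantized modules, not part of the patch. The scales are illustrative and would normally come from calibration, and a CUDA device is assumed.

import torch
from kernl.implementations.linear_layer_quant import W8A8B8O8Linear

fp_linear = torch.nn.Linear(512, 1024, bias=True)
x = torch.randn(8, 512)
x_scale = x.abs().max() / 127             # per-tensor absmax scale of the activations
y_scale = fp_linear(x).abs().max() / 127  # output scale, here taken from a float dry run

int8_linear = W8A8B8O8Linear.from_float(fp_linear, x_scale, y_scale).cuda()
q_x = (x / x_scale).round().to(torch.int8).cuda()
q_y = int8_linear(q_x)            # int8 activations in, int8 activations out
y_approx = q_y.float() * y_scale  # dequantized approximation of fp_linear(x)
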
diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py
index ba777d2d..cac052ad 100644
--- a/test/test_linear_layer.py
+++ b/test/test_linear_layer.py
@@ -41,7 +41,9 @@ def get_pytorch_activation(activation: str) -> Callable:
     "pytorch": lambda weight, bias, activation: lambda x: get_pytorch_activation(activation)(
         torch.nn.functional.linear(x, weight, bias)
     ),
-    "triton": lambda weight, bias, activation: lambda x: linear_layer(x, weight, bias, activation),
+    "triton": lambda weight, bias, activation: lambda x: linear_layer(
+        x=x, weight=weight, bias=bias, activation=activation
+    ),
 }
diff --git a/test/test_linear_layer_quant.py b/test/test_linear_layer_quant.py
new file mode 100644
index 00000000..57cfb165
--- /dev/null
+++ b/test/test_linear_layer_quant.py
@@ -0,0 +1,100 @@
+# Copyright 2022 Lefebvre Sarrut
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+import torch
+
+from conftest import assert_all_close, set_seed
+
+from kernl.implementations.linear_layer import linear_layer
+from kernl.implementations.linear_layer_quant import W8A8B8O8Linear, W8A8B8O8LinearReLU, W8A8BFP32OFP32Linear
+
+
+@set_seed()
+@pytest.mark.parametrize("implementation", ["triton", "pytorch"])
+def test_quant_linear_a8_w8_b32_o32(benchmark, implementation):
+    alpha, beta = 0.01, 0.0001
+    B, M, N = 128, 512, 1024
+    min_int8 = torch.iinfo(torch.int8).min
+    max_int8 = torch.iinfo(torch.int8).max
+    min_bias = int(torch.finfo(torch.float16).min)  # because the bias will be converted to fp16 for benchmark reasons
+    max_bias = int(torch.finfo(torch.float16).max)
+    weight = torch.randint(min_int8, max_int8, (N, M), dtype=torch.int8, device="cuda")
+    bias = torch.randint(min_bias, max_bias, (N,), dtype=torch.int32, device="cuda")
+    x = torch.randint(min_int8, max_int8, (B, M), dtype=torch.int8, device="cuda")
+    linear = torch.nn.Linear(M, N, bias=True)
+    linear.weight.data = weight.half() * alpha
+    linear.bias.data = bias.half() * beta
+    y_pytorch = linear(x.half())
+    assert torch.all(torch.isfinite(y_pytorch))
+
+    if implementation == "triton":
+        y_triton = torch.zeros((B, N), device="cuda")
+        y_triton = linear_layer(
+            x=x,
+            weight=weight,
+            bias=bias,
+            activation="",
+            act_inputs=None,
+            alpha_scaler=alpha,
+            beta_scaler=beta,
+            output=y_triton,
+        )
+        assert_all_close(y_pytorch, y_triton.half(), rtol=0, atol=4)  # not strictly equal, as the baseline is computed with floats
+        fn = lambda: linear_layer(  # noqa: E731
+            x=x,
+            weight=weight,
+            bias=bias,
+            activation="",
+            act_inputs=None,
+            alpha_scaler=alpha,
+            beta_scaler=beta,
+            output=y_triton,
+        )
+    elif implementation == "pytorch":
+        x = x.half()
+        fn = lambda: linear(x)  # noqa: E731
+    else:
+        raise ValueError(f"Unknown implementation: {implementation}")
+
+    benchmark(fn)
+
+
+@set_seed()
+@torch.no_grad()
+@pytest.mark.parametrize("implementation", ["w8a8b8o8", "w8a8b8o8_relu", "w8a8bfp32ofp32"])
+def test_w8a8b8o8_linear_relu(implementation):
+    B, M, N = 128, 512, 1024
+    x = torch.randn(B, M)
+    x_scale = x.abs().max() / 127
+    qx = (x / x_scale).round().to(torch.int8)
+    linear = torch.nn.Linear(M, N, bias=True)
+    y_pytorch = linear(x)
+    if implementation == "w8a8b8o8_relu":
+        y_pytorch = y_pytorch.clamp(min=0)
+    if implementation != "w8a8bfp32ofp32":
+        y_scale = y_pytorch.abs().max() / 127
+    if implementation == "w8a8b8o8":
+        linear_quant = W8A8B8O8Linear.from_float(linear, x_scale, y_scale).cuda()
+    elif implementation == "w8a8b8o8_relu":
+        linear_quant = W8A8B8O8LinearReLU.from_float(linear, x_scale, y_scale).cuda()
+    elif implementation == "w8a8bfp32ofp32":
+        linear_quant = W8A8BFP32OFP32Linear.from_float(linear, x_scale).cuda()
+    else:
+        raise ValueError(f"Unknown implementation: {implementation}")
+    q_y = linear_quant(qx.cuda()).cpu()
+    y_quant = q_y * y_scale if implementation != "w8a8bfp32ofp32" else q_y
+    assert_all_close(y_pytorch, y_quant, rtol=0, atol=1e-1)
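
Note: as a plain-PyTorch cross-check of what the fused int8 path computes, the sketch below is a rough reference only and not part of the patch; the exact rounding and saturation applied when the kernel stores the int8 result may differ slightly, which is why the test above tolerates atol=4.

import torch


def reference_int8_linear(x_q, w_q, bias, alpha, beta, relu=False):
    # int8 x int8 matmul accumulated in int32, as selected by ACC_TYPE for non-float outputs
    acc = x_q.to(torch.int32) @ w_q.to(torch.int32).t()
    out = alpha * acc.float() + beta * bias.float()
    if relu:
        out = out.clamp(min=0)
    # approximate the kernel's int8 store (rounding/saturation assumed)
    return out.round().clamp(-128, 127).to(torch.int8)
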