From 2ac7e7df0eaf8abd9f1118b3bc22b5aceb58ba57 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 2 Apr 2025 14:40:40 -0700 Subject: [PATCH] Preshuffled BF16I4 Gemm Kernel (#3913) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/1003 This diff adds a preshuffled BF16I4 mixed dtype kernel using cutlass. Performance is quite compelling and shows substantial speedups for some shapes compared to bf16 x bf16 gemm backed by cublas. Notably, this preshuffle approach is 1.5-2X faster than the standard bf16i4 gemm for most shapes. Compared to other mixed dtype kernels like marlin and machete, we see that this new kernel is probably the best average performer. {F1976677491} Reviewed By: jianyuh Differential Revision: D72270467 --- .../experimental/gen_ai/bench/quantize_ops.py | 55 ++- .../experimental/gen_ai/gen_ai/quantize.py | 108 +++-- .../quantize/cutlass_extensions/bf16i4bf16.cu | 412 ++++++++++++++++++ .../cutlass_extensions/bf16i4bf16_rowwise.cu | 297 ------------- .../cutlass_extensions/f8i4bf16_shuffled.cu | 4 +- .../cutlass_extensions/mixed_dtype_utils.cu | 59 ++- .../gen_ai/src/quantize/quantize.cpp | 25 +- .../gen_ai/test/quantize/quantize_test.py | 2 +- 8 files changed, 609 insertions(+), 353 deletions(-) create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16.cu delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise.cu diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py index 333cd42c0d..07945414f9 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py @@ -1426,7 +1426,7 @@ def cuda(self) -> bool: class F8I4ShuffledGemm(QuantizeOpBase): def preprocess(self, x, w): # Prequantize and pack weights. - wq, row_scale, group_scale = quantize_int4_preshuffle(w) + wq, (group_scale, row_scale) = quantize_int4_preshuffle(w) return x, wq, row_scale, group_scale def quantize(self, x, wq, row_scale, group_scale): @@ -1470,6 +1470,49 @@ def cuda(self) -> bool: return True +@register_quantize_op +class BF16I4ShuffledGemm(QuantizeOpBase): + def preprocess(self, x, w): + # Prequantize and pack weights. + wq, (group_scale, group_zero) = quantize_int4_preshuffle(w, dtype="bf16") + return x, wq, group_scale, group_zero + + def quantize(self, x, wq, group_scale, group_zero): + # No extra action required. + return x, wq, group_scale, group_zero + + def compute(self, x, wq, group_scale, group_zero): + # Handle batched cases by looping over each batch. + if x.dim() == 3: + B, M, _ = x.shape + _, N, _ = wq.shape + y = torch.empty((B, M, N), device=x.device, dtype=torch.bfloat16) + for i in range(B): + y[i] = torch.ops.fbgemm.bf16i4bf16_shuffled( + x[i], wq[i], group_scale[i], group_zero[i] + ) + return y + # Otherwise run gemm normally. + return torch.ops.fbgemm.bf16i4bf16_shuffled(x, wq, group_scale, group_zero) + + def quantize_and_compute(self, x, wq, group_scale, group_zero): + x, wq, group_scale, group_zero = self.quantize(x, wq, group_scale, group_zero) + return self.compute(x, wq, group_scale, group_zero) + + @property + def name(self) -> str: + return "cutlass_bf16i4_preshuffle" + + @property + def hip(self) -> bool: + # Not yet supported on AMD. + return False + + @property + def cuda(self) -> bool: + return True + + @register_quantize_op class F8I4ShuffledGroupedGemm(QuantizeOpBase): """ @@ -1485,7 +1528,8 @@ def preprocess(self, x, w): m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device) # Quantize weights. # TODO Only rowwise scaling is currently supported. This needs to be fixed. - wq, row_scale, group_scale = zip(*[quantize_int4_preshuffle(i) for i in w]) + wq, scales = zip(*[quantize_int4_preshuffle(i) for i in w]) + group_scale, row_scale = zip(*scales) # Group weights as single tensor. wq = torch.stack(wq, dim=0).contiguous() row_scale = torch.stack(row_scale, dim=0).contiguous() @@ -1580,7 +1624,12 @@ def quantize(self, x, w): wq, w_scale, w_zp = self._int4_row_quantize(w) # Pack int4 values together. wq = self._pack_int4(wq) - return x.to(torch.bfloat16), wq, w_scale, w_zp + return ( + x.to(torch.bfloat16), + wq, + w_scale, + w_zp, + ) def compute(self, x, wq, w_scale, w_zp): return torch.ops.fbgemm.bf16i4bf16_rowwise(x, wq, w_scale, w_zp) diff --git a/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py b/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py index 89b4441599..2959c83366 100644 --- a/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py +++ b/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py @@ -29,6 +29,34 @@ def pack_int4(x: torch.Tensor) -> torch.Tensor: return torch.bitwise_or(low_x, high_x).contiguous() +def int4_row_quantize_zp( + x: torch.Tensor, + group_size: int = 128, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + n_bit = 4 # Number of target bits. + to_quant = x.reshape(-1, group_size).to(torch.float) + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2**n_bit - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + + zeros = min_val + scales * (2 ** (n_bit - 1)) + + out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int) + + # Recenter output and move to int8. + out = (out - 2 ** (n_bit - 1)).to(dtype=torch.int8).reshape(x.shape) + + # Cutlass expects column major layout for scale and zero point, + # so we transpose here and make them contiguous. + scales = scales.view(x.shape[0], -1).t().contiguous() + zeros = zeros.view(x.shape[0], -1).t().contiguous() + + return out, scales, zeros + + def int4_row_quantize( x: torch.Tensor, group_size: int = 128, @@ -63,8 +91,8 @@ def int4_row_quantize( def quantize_int4_preshuffle( - w: torch.Tensor, group_size: int = 128 -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + w: torch.Tensor, group_size: int = 128, dtype: str = "fp8" +) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Quantizes an input weight tensor to int4 using preshuffling and scale packing. This function is intended to be used with fbgemms mixed dtype kernels and is expected @@ -73,47 +101,57 @@ def quantize_int4_preshuffle( Args: w (Tensor): [N, K] Higher precision weight tensor to quantize. May optionally have a batch dimension. group_size (int): Number of elements to calculate group scale for, must be at least 128. + dtype (torch.dtype): Type of corresponding activations. Must be fp8 or bf16. Returns: wq (Tensor): [N, K // 2] Quantized int4 weight tensor packed into int8 elements. - row_scale (Tensor): [N] FP32 Scale per row of the weight tensor. - group_scale (Tensor): [K / group_size, 8, N] FP8 Scale per group of the weight tensor. + scales (Tuple[Tensor]): Scale tensors for the specified activation type. When FP8 is used, + scales is a tuple of row_scale ([N]) and group_scale ([K / group_size, 8, N]). When BF16 is + used, scales is a tuple of group_scale([K / group_size, N]) and group_zero ([K / group_size, N]) """ - def _quantize(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # Start by lowering weights to FP8 and producing row scales. - wq, row_scale = quantize_fp8_row(w) - - # Now reduce to INT4. - wq, group_scale = int4_row_quantize(wq, group_size) - # Reduce group scale to FP8. - group_scale = group_scale.to(torch.float8_e4m3fn) - - # Take quantized weights and pack them efficiently. - wq = pack_int4(wq) - - # Finally pack weights and scales into efficient preshuffled format. - wq, group_scale = torch.ops.fbgemm.preshuffle_i4(wq, group_scale) - - return wq, row_scale, group_scale + def _quantize( + w: torch.Tensor, dtype: str = "fp8" + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + if dtype == "fp8": + # Start by lowering weights to FP8 and producing row scales. + wq, row_scale = quantize_fp8_row(w) + + # Now reduce to INT4. + wq, group_scale = int4_row_quantize(wq, group_size) + # Reduce group scale to FP8. + group_scale = group_scale.to(torch.float8_e4m3fn) + # Take quantized weights and pack them efficiently. + wq = pack_int4(wq) + # Finally pack weights and scales into efficient preshuffled format. + wq, group_scale = torch.ops.fbgemm.preshuffle_i4(wq, group_scale) + return wq, (group_scale, row_scale) + + elif dtype == "bf16": + wq, group_scale, group_zero = int4_row_quantize_zp(w, group_size) + # Set scales to activation type. + group_scale = group_scale.to(torch.bfloat16) + group_zero = group_zero.to(torch.bfloat16) + # Take quantized weights and pack them efficiently. + wq = pack_int4(wq) + # Finally pack weights and scales into efficient preshuffled format. + wq, group_scale = torch.ops.fbgemm.preshuffle_i4(wq, group_scale) + return wq, (group_scale, group_zero) + else: + raise NotImplementedError("Only fp8 and bf16 activations supported.") if w.ndim >= 3: orig_shape = w.shape # Flatten to 3 dimensions then iterate over batches. - w = w.view(-1, *w.shape[1:]) - w.unbind(dim=0) - wq = [] - row_scale = [] - group_scale = [] - for batch in w: - wq_, row_scale_, group_scale_ = _quantize(batch) - wq.append(wq_) - row_scale.append(row_scale_) - group_scale.append(group_scale_) + wq, scales = zip(*[_quantize(i, dtype=dtype) for i in w]) wq = torch.stack(wq).view(*orig_shape[:-2], *wq[0].shape) - row_scale = torch.stack(row_scale).view(*orig_shape[:-2], *row_scale[0].shape) - group_scale = torch.stack(group_scale).view( - *orig_shape[:-2], *group_scale[0].shape + # Decompose then stack scales back into a tuple. + a_scales, b_scales = zip(*scales) + scales = ( + torch.stack(a_scales).view(*orig_shape[:-2], *a_scales[0].shape), + torch.stack(b_scales).view(*orig_shape[:-2], *b_scales[0].shape), ) else: - wq, row_scale, group_scale = _quantize(w) - return wq, row_scale, group_scale + wq, scales = _quantize(w, dtype=dtype) + + return wq, scales diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16.cu new file mode 100644 index 0000000000..4dd335cdc0 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16.cu @@ -0,0 +1,412 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/mixed_dtype_utils.hpp" +#include "cutlass/util/packed_stride.hpp" + +namespace fbgemm_gpu { + +#if CUDART_VERSION >= 12000 + +template < + bool SHUFFLE, + typename SCALE_TYPE, + int TB_M, + int TB_N, + int TBS_M, + int TBS_N, + int TBS_K, + bool COOP> +at::Tensor _bf16i4bf16( + at::Tensor X, + at::Tensor W, + at::Tensor w_scale_group, + at::Tensor w_zero_group, + at::Tensor Y) { + // Get shape information from input tensors. + int M = size_to_dim_(X.dim() - 1, X.sizes()); + int K = X.size(-1); + int N = size_to_dim_(W.dim() - 1, W.sizes()); + int num_groups = w_scale_group.size(0); + TORCH_CHECK( + w_zero_group.size(0) == num_groups, + "Scales and zeros must be the same shape."); + int group_size = K / num_groups; + + // Define input types. + using MmaType = cutlass::bfloat16_t; + using QuantType = cutlass::int4b_t; + constexpr int TileShapeK = 128 * 8 / cute::sizeof_bits::value; + + // A Matrix configuration. + using ElementA = MmaType; + using LayoutA = cutlass::layout::RowMajor; + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + // B Matrix Configuration. + using ElementB = QuantType; + using LayoutB = cutlass::layout::ColumnMajor; + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + // We need to manually swap and transpose inputs. Unclear how required this is + // though. + using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; + using LayoutB_Transpose = + typename cutlass::layout::LayoutTranspose::type; + + using StrideA = cutlass::detail::TagToStrideA_t; + using StrideB = cutlass::detail::TagToStrideB_t; + + // Define layout for shuffled weight tensor. + using ValueShuffle = cute::Layout< + cute::Shape, + cute::Stride>; // order [0,2,4,6,1,3,5,7] + int constexpr NumShuffleAtoms = 1; + using MmaAtomShape = + cute::Layout>>; + using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom< + MmaType, + MmaAtomShape, + ValueShuffle>()); + using LayoutB_Reordered = decltype(cute::tile_to_shape( + LayoutAtomQuant{}, cute::Layout, StrideB>{})); + + using B_Layout = + cute::conditional_t; + + using ElementScale = SCALE_TYPE; + using ElementZero = ElementScale; + + // Output Matrix configuration. + using ElementC = cutlass::bfloat16_t; + using LayoutC = cutlass::layout::RowMajor; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + // Core kernel configurations + using ElementAccumulator = float; + using ElementCompute = float; + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = + cute::Shape, cute::Int, cute::Int>; + using ClusterShape = + cute::Shape, cute::Int, cute::Int>; + using KernelSchedule = cute::conditional_t< + COOP, + cutlass::gemm::KernelTmaWarpSpecializedCooperative, + cutlass::gemm::KernelTmaWarpSpecialized>; + // Might be the only epilogue schedule that supports swap + transpose. + using EpilogueSchedule = cute::conditional_t< + COOP, + cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::epilogue::TmaWarpSpecialized>; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + EpilogueTileType, + ElementAccumulator, + ElementAccumulator, + void, // Indicate there is no beta scaling. + typename cutlass::layout::LayoutTranspose::type, + AlignmentC, + ElementC, + typename cutlass::layout::LayoutTranspose::type, + AlignmentC, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloopShuffled = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + cute::tuple, + B_Layout, + AlignmentB, + ElementA, + LayoutA_Transpose, + AlignmentA, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloopShuffled, + CollectiveEpilogue>; + + using GemmShuffled = + cutlass::gemm::device::GemmUniversalAdapter; + + using StrideC = typename GemmKernelShuffled::StrideC; + + /// Initialization + auto shape_B = cute::make_shape(N, K, 1); + StrideA stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1)); + StrideB stride_B = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, 1)); + StrideC stride_C = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(N, M, 1)); + LayoutB_Reordered layout_B_reordered = + cute::tile_to_shape(LayoutAtomQuant{}, shape_B); + + using stride_type = cute::conditional_t; + stride_type B_stride; + if constexpr (SHUFFLE) { + B_stride = layout_B_reordered; + } else { + B_stride = stride_B; + } + + using StrideS = typename CollectiveMainloopShuffled::StrideScale; + StrideS stride_S = cutlass::make_cute_packed_stride( + StrideS{}, cute::make_shape(N, num_groups, 1)); + + // Define Gemm arguments. + typename GemmShuffled::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {N, M, K, 1}, + {reinterpret_cast(W.data_ptr()), + B_stride, + reinterpret_cast(X.data_ptr()), + stride_A, + reinterpret_cast(w_scale_group.data_ptr()), + stride_S, + group_size, + reinterpret_cast(w_zero_group.data_ptr())}, + {{}, + nullptr, + stride_C, + reinterpret_cast(Y.data_ptr()), + stride_C}}; + + // Launch the workload. + GemmShuffled gemm; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + int workspace_size = GemmShuffled::get_workspace_size(arguments); + + // Allocate workspace memory + at::Tensor workspace = + at::empty(workspace_size, X.options().dtype(at::kByte)); + + // Check the problem size is supported or not + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm.initialize(arguments, workspace.data_ptr()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm(at::cuda::getCurrentCUDAStream()); + + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return Y; +} + +template +at::Tensor bf16i4bf16_dispatch( + at::Tensor X, + at::Tensor W, + at::Tensor w_scale_group, + at::Tensor w_zero_group) { + int M = size_to_dim_(X.dim() - 1, X.sizes()); + int K = X.size(-1); + int N = size_to_dim_(W.dim() - 1, W.sizes()); + // Check input types and shapes. + TORCH_CHECK( + X.is_cuda() && X.is_contiguous() && X.dtype() == at::kBFloat16, + "X must be BF16 and contiguous on GPU."); + TORCH_CHECK( + W.size(-1) == K / 2 && W.is_cuda() && W.is_contiguous() && + W.dtype() == at::kChar, + "W should be int8 (which represent two int4 values), have shape [..., N, K/2], " + "and be contiguous on GPU."); + // Make sure group scales and zeros are in proper format. + TORCH_CHECK( + w_scale_group.dim() == 2 && w_scale_group.size(1) == N, + "Group scales are expected to have shape [num_groups, N]."); + + // Allocate output or return an empty tensor if input is empty. + if (M == 0 || N == 0 || K == 0) { + return at::zeros({M, N}, X.options().dtype(at::kBFloat16)); + } + at::Tensor Y = at::empty({M, N}, X.options().dtype(at::kBFloat16)); + + // Use shape heuristics to dispatch to optimized kernel configuration. + if (M <= 16) { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } else if (M <= 32) { + if (N <= 4096) { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } else { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } + } else if (M <= 64) { + if (N <= 2048) { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } else if (N <= 4096) { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } else { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } + } else if (M <= 128) { + if (N <= 1024) { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } else if (N <= 2048) { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } else if (N <= 4096) { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } else { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } + } else if (M <= 256) { + if (N <= 1024) { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } else if (N <= 2048) { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } else if (N <= 4096) { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } else { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } + } else if (M <= 512) { + if (N <= 1024) { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } else if (N <= 2048) { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } else if (N <= 4096) { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } else { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } + } else if (M <= 1024) { + if (N <= 1024) { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } else if (N <= 2048) { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } else { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } + } else { + if (M <= 2048 && N <= 1024) { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } else { + return _bf16i4bf16( + X, W, w_scale_group, w_zero_group, Y); + } + } +} + +at::Tensor bf16i4bf16_shuffled( + at::Tensor X, + at::Tensor W, + at::Tensor w_scale_group, + at::Tensor w_zero_group) { + if (w_scale_group.dtype() == at::kFloat) { + return bf16i4bf16_dispatch(X, W, w_scale_group, w_zero_group); + } else if (w_scale_group.dtype() == at::kBFloat16) { + return bf16i4bf16_dispatch( + X, W, w_scale_group, w_zero_group); + } else { + TORCH_CHECK(false, "Only fp32 an bf16 scales supported.") + } +} + +at::Tensor bf16i4bf16_rowwise( + at::Tensor X, // BF16 + at::Tensor W, // INT4 + at::Tensor w_scale_group, + at::Tensor w_zero_group) { + if (w_scale_group.dtype() == at::kFloat) { + return bf16i4bf16_dispatch(X, W, w_scale_group, w_zero_group); + } else if (w_scale_group.dtype() == at::kBFloat16) { + return bf16i4bf16_dispatch( + X, W, w_scale_group, w_zero_group); + } else { + TORCH_CHECK(false, "Only fp32 an bf16 scales supported.") + } +} + +#else + +at::Tensor bf16i4bf16_shuffled( + at::Tensor X, + at::Tensor W, + at::Tensor w_scale_group, + at::Tensor w_zero_group) { + throw std::runtime_error( + "CUDA version is older than 12.0"); // requires CUDA>=12 +} + +at::Tensor bf16i4bf16_rowwise( + at::Tensor X, // BF16 + at::Tensor W, // INT4 + at::Tensor w_scale_group, + at::Tensor w_zero_group) { + throw std::runtime_error( + "CUDA version is older than 12.0"); // requires CUDA>=12 +} + +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise.cu deleted file mode 100644 index 61bdab82f7..0000000000 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise.cu +++ /dev/null @@ -1,297 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include -#include - -// clang-format off -// The fixed ordering of the headers is required for CUTLASS 3.2+ -#include -#include // @manual -#include // @manual -#include // @manual -// clang-format on - -#include "cutlass_extensions/include/kernel_mode.h" - -namespace fbgemm_gpu { - -#if CUDART_VERSION >= 12000 - -template < - int TB_M, - int TB_N, - int TB_K, - int TBS_M, - int TBS_N, - int TBS_K, - bool PONG, - typename WEIGHT_SCALE_DTYPE> -at::Tensor bf16i4bf16_rowwise_impl( - at::Tensor X, // BF16 - at::Tensor WQ, // INT4 - at::Tensor w_scale, - at::Tensor w_zp) { - int M = X.size(0); - int N = WQ.size(0); - int K = X.size(1); - - int num_groups = w_scale.size(0); - - TORCH_CHECK(X.is_cuda() && X.is_contiguous()); - TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); - TORCH_CHECK(w_scale.is_cuda() && w_scale.is_contiguous()); - TORCH_CHECK(w_zp.is_cuda() && w_zp.is_contiguous()); - TORCH_CHECK(K >= num_groups && K % num_groups == 0); - - int group_size = K / num_groups; - - auto Y = at::empty({M, N}, X.options().dtype(at::kBFloat16)); - - using ElementInputA = cutlass::bfloat16_t; - using LayoutInputA = cutlass::layout::ColumnMajor; - constexpr int AlignmentInputA = - 128 / - cutlass::sizeof_bits< - ElementInputA>::value; // Memory access granularity/alignment of A - // matrix in units of elements (up to 16 bytes) - - using ElementInputB = cutlass::int4b_t; - using LayoutInputB = cutlass::layout::RowMajor; - constexpr int AlignmentInputB = - 128 / - cutlass::sizeof_bits< - ElementInputB>::value; // Memory access granularity/alignment of B - // matrix in units of elements (up to 16 bytes) - - using ElementScale = WEIGHT_SCALE_DTYPE; - using ElementZeroPoint = WEIGHT_SCALE_DTYPE; - using ElementComputeEpilogue = float; - using ElementAccumulator = float; - - using ElementOutput = cutlass::bfloat16_t; - using LayoutOutput = cutlass::layout::ColumnMajor; - constexpr int AlignmentOutput = - 128 / - cutlass::sizeof_bits< - ElementOutput>::value; // Memory access granularity/alignment of C - // matrix in units of elements (up to 16 bytes) - - using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that - // supports the intended feature - using OperatorClass = cutlass::arch::OpClassTensorOp; - using TileShape = cute::Shape< - cute::Int, - cute::Int, - cute::Int>; // Threadblock-level - // tile size - using ClusterShape = cute::Shape< - cute::Int, - cute::Int, - cute::Int>; // Shape of the - // threadblocks in a - // cluster - using CooperativeSchedule = - cutlass::gemm::KernelTmaWarpSpecializedCooperative; - using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using CooperativeEpilogueSchedule = - cutlass::epilogue::TmaWarpSpecializedCooperative; - using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; - using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; - using MainLoopSchedule = - cute::conditional_t; - using EpilogueSchedule = cute:: - conditional_t; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, - cutlass::arch::OpClassTensorOp, - TileShape, - ClusterShape, - EpilogueTileType, - ElementAccumulator, - ElementAccumulator, - ElementOutput, - LayoutOutput, - AlignmentOutput, - ElementOutput, - LayoutOutput, - AlignmentOutput, - EpilogueSchedule>::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - cute::tuple, - LayoutInputB, - AlignmentInputB, - ElementInputA, - LayoutInputA, - AlignmentInputA, - ElementAccumulator, - TileShape, - ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - MainLoopSchedule>::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - cute::Shape, - CollectiveMainloop, - CollectiveEpilogue>; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - using StrideInputA = typename Gemm::GemmKernel::StrideA; - using StrideInputB = typename Gemm::GemmKernel::StrideB; - using StrideOutput = typename Gemm::GemmKernel::StrideC; - using StrideS = typename CollectiveMainloop::StrideScale; - - StrideInputA stride_a = cutlass::make_cute_packed_stride( - StrideInputA{}, cute::make_shape(M, K, 1)); - StrideInputB stride_b = cutlass::make_cute_packed_stride( - StrideInputB{}, cute::make_shape(N, K, 1)); - StrideOutput stride_output = cutlass::make_cute_packed_stride( - StrideOutput{}, cute::make_shape(N, M, 1)); - StrideS stride_S = cutlass::make_cute_packed_stride( - StrideS{}, cute::make_shape(N, num_groups, 1)); - - typename Gemm::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - {N, M, K}, - {reinterpret_cast(WQ.data_ptr()), - stride_b, - reinterpret_cast(X.data_ptr()), - stride_a, - reinterpret_cast(w_scale.data_ptr()), - stride_S, - group_size, - reinterpret_cast(w_zp.data_ptr())}, - {{1.0, 0.0}, - (ElementOutput*)Y.data_ptr(), - stride_output, - (ElementOutput*)Y.data_ptr(), - stride_output}}; - - Gemm gemm; - - // Using the arguments, query for extra workspace required for matrix - // multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - cutlass::device_memory::allocation workspace(workspace_size); - - // Check the problem size is supported or not - cutlass::Status status = gemm.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement"); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize"); - } - - status = gemm(at::cuda::getCurrentCUDAStream()); - - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error( - std::string("cutlass cannot run") + - cutlass::cutlassGetStatusString(status)); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - return Y; -} - -template -at::Tensor dispatch_bf16i4bf16_rowwise_kernel( - at::Tensor X, // BF16 - at::Tensor WQ, // INT4 - at::Tensor w_scale, - at::Tensor w_zp) { - KernelMode kernel = get_kernel_mode(X, WQ); - if (kernel == KernelMode::Small) { - return bf16i4bf16_rowwise_impl< - 64, - 128, - 128, - 2, - 1, - 1, - true, - WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp); - } else if (kernel == KernelMode::Large) { - return bf16i4bf16_rowwise_impl< - 128, - 256, - 64, - 2, - 1, - 1, - false, - WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp); - } else { - return bf16i4bf16_rowwise_impl< - 128, - 256, - 64, - 2, - 1, - 1, - false, - WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp); - } -} - -at::Tensor bf16i4bf16_rowwise( - at::Tensor X, // BF16 - at::Tensor WQ, // INT4 - at::Tensor w_scale, - at::Tensor w_zp) { - // Check datatypes. - TORCH_CHECK( - (w_scale.dtype() == at::kFloat && w_zp.dtype() == at::kFloat) || - (w_scale.dtype() == at::kHalf && w_zp.dtype() == at::kHalf) || - (w_scale.dtype() == at::kBFloat16 && w_zp.dtype() == at::kBFloat16), - "Weight scale and zero point tensors must be float32, bfloat16, or float16, and dtype of weight scale and zero point tensors must be the same ."); - - if (w_scale.dtype() == at::kFloat) { - return dispatch_bf16i4bf16_rowwise_kernel(X, WQ, w_scale, w_zp); - } else if (w_scale.dtype() == at::kHalf) { - return dispatch_bf16i4bf16_rowwise_kernel( - X, WQ, w_scale, w_zp); - } else if (w_scale.dtype() == at::kBFloat16) { - return dispatch_bf16i4bf16_rowwise_kernel( - X, WQ, w_scale, w_zp); - } else { - throw std::runtime_error( - "Weight scale and zero point data type not supported in bf16i4bf16_rowwise"); - } -} - -#else - -at::Tensor bf16i4bf16_rowwise( - at::Tensor X, // BF16 - at::Tensor WQ, // INT4 - at::Tensor w_scale, - at::Tensor w_zp) { - throw std::runtime_error( - "CUDA version is older than 12.0"); // requires CUDA>=12 -} - -#endif - -} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_shuffled.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_shuffled.cu index 0e459d8b09..d3ccc5c059 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_shuffled.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_shuffled.cu @@ -146,7 +146,7 @@ at::Tensor _f8i4bf16_shuffled( EpilogueTileType, ElementAccumulator, ElementAccumulator, - ElementC, + void, typename cutlass::layout::LayoutTranspose::type, AlignmentC, ElementC, @@ -207,7 +207,7 @@ at::Tensor _f8i4bf16_shuffled( stride_S, group_size}, {{}, - reinterpret_cast(Y.data_ptr()), + nullptr, stride_C, reinterpret_cast(Y.data_ptr()), stride_C}}; diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mixed_dtype_utils.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mixed_dtype_utils.cu index f741c1e259..ee6882687b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mixed_dtype_utils.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mixed_dtype_utils.cu @@ -17,15 +17,9 @@ namespace fbgemm_gpu { -std::tuple preshuffle_i4( +std::tuple fp8_preshuffle_i4( at::Tensor WQ, at::Tensor w_scale) { - // Check that w_scale is proper type. if not, quantize it. - if (w_scale.dtype() != at::kFloat8_e4m3fn) { - TORCH_WARN( - "Weight scale must be FP8 for preshuffled GEMM. Performing downcasting."); - w_scale = w_scale.to(WQ.options().dtype(at::kFloat8_e4m3fn)); - } // Start by allocating space for shuffled tensors. at::Tensor WQ_shuffled = at::empty_like(WQ); // Packed scale contains 8 lookup values for each original scale element. @@ -73,4 +67,55 @@ std::tuple preshuffle_i4( return {WQ_shuffled, w_scale_packed}; } +std::tuple bf16_preshuffle_i4( + at::Tensor WQ, + at::Tensor w_scale) { + // For bf16 we only preshuffle the weight tensor, scales arent modified. + // Next we need to shuffle B. To do this, we define a few helper objects. + const int N = WQ.size(0); + const int K = 2 * WQ.size(1); + auto shape_B = cute::make_shape(N, K, 1); + using LayoutB = cutlass::layout::ColumnMajor; + using StrideB = cutlass::detail::TagToStrideB_t; + using ValueShuffle = cute::Layout< + cute::Shape, + cute::Stride>; // order [0,2,4,6,1,3,5,7] + int constexpr NumShuffleAtoms = 1; + using MmaAtomShape = + cute::Layout>>; + using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom< + cutlass::bfloat16_t, + MmaAtomShape, + ValueShuffle>()); + using LayoutB_Reordered = decltype(cute::tile_to_shape( + LayoutAtomQuant{}, cute::Layout, StrideB>{})); + StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B); + auto layout_B = make_layout(shape_B, stride_B); + LayoutB_Reordered layout_B_reordered = + cute::tile_to_shape(LayoutAtomQuant{}, shape_B); + ; + + // Now we're ready to reorder the tensor into proper layout. + cutlass::reorder_tensor( + reinterpret_cast(WQ.data_ptr()), + layout_B, + layout_B_reordered); + + // Tensors should now be preshuffled and ready for use. + return {WQ, w_scale}; +} + +std::tuple preshuffle_i4( + at::Tensor WQ, + at::Tensor w_scale) { + TORCH_CHECK( + w_scale.dtype() == at::kFloat8_e4m3fn || w_scale.dtype() == at::kBFloat16, + "Activation type must be FP8 or BF16."); + if (w_scale.dtype() == at::kFloat8_e4m3fn) { + return fp8_preshuffle_i4(WQ, w_scale); + } else { + return bf16_preshuffle_i4(WQ, w_scale); + } +} + } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index 07051313c1..6184e50485 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -148,6 +148,11 @@ at::Tensor f8i4bf16_shuffled( at::Tensor x_scale, at::Tensor w_scale, at::Tensor w_scale_group); +at::Tensor bf16i4bf16_shuffled( + at::Tensor X, + at::Tensor W, + at::Tensor w_scale_group, + at::Tensor w_zero_group); at::Tensor f8i4bf16_shuffled_grouped( at::Tensor XQ, at::Tensor WQ, @@ -160,9 +165,9 @@ std::tuple preshuffle_i4( at::Tensor w_scale); at::Tensor bf16i4bf16_rowwise( at::Tensor X, - at::Tensor WQ, - at::Tensor w_scale, - at::Tensor w_zp); + at::Tensor W, + at::Tensor w_scale_group, + at::Tensor w_zero_group); at::Tensor bf16i4bf16_rowwise_batched( at::Tensor X, at::Tensor WQ, @@ -217,6 +222,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "f8i4bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_zp) -> Tensor"); m.def( "f8i4bf16_shuffled(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_scale_group) -> Tensor"); + m.def( + "bf16i4bf16_shuffled(Tensor X, Tensor W, Tensor w_scale_group, Tensor w_zero_group) -> Tensor"); m.def( "f8i4bf16_shuffled_grouped(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_scale_group, Tensor M_sizes) -> Tensor"); m.impl("f8i4bf16_shuffled", f8i4bf16_shuffled); @@ -227,7 +234,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "fp8fp8bf16_fast_gemv(Tensor X, Tensor W, Tensor x_scale, Tensor w_scale) -> Tensor"); m.def("f8f8bf16_lite(Tensor XQ, Tensor WQ, Tensor scale) -> Tensor"); m.def( - "bf16i4bf16_rowwise(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor"); + "bf16i4bf16_rowwise(Tensor X, Tensor W, Tensor w_scale_group, Tensor w_zero_group) -> Tensor"); m.def( "bf16i4bf16_rowwise_batched(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor"); m.def( @@ -318,6 +325,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { m.impl("f8f8bf16_lite", f8f8bf16_lite); m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise); m.impl("f8i4bf16_shuffled", f8i4bf16_shuffled); + m.impl("bf16i4bf16_shuffled", bf16i4bf16_shuffled); m.impl("f8i4bf16_shuffled_grouped", f8i4bf16_shuffled_grouped); m.impl("preshuffle_i4", preshuffle_i4); m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched); @@ -352,6 +360,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { m.impl("f8f8bf16_lite", f8f8bf16_lite); m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise); m.impl("f8i4bf16_shuffled", f8i4bf16_shuffled); + m.impl("bf16i4bf16_shuffled", bf16i4bf16_shuffled); m.impl("f8i4bf16_shuffled_grouped", f8i4bf16_shuffled_grouped); m.impl("preshuffle_i4", preshuffle_i4); m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched); @@ -518,12 +527,12 @@ at::Tensor f8i4bf16_rowwise_meta( at::Tensor bf16i4bf16_rowwise_meta( at::Tensor X, // BF16 - at::Tensor WQ, // INT4 - at::Tensor /* w_scale */, - at::Tensor /* w_zp */ + at::Tensor W, // INT4 + at::Tensor /* w_scale_group */, + at::Tensor /* w_zero_group */ ) { int M = X.size(0); - int N = WQ.size(0); + int N = W.size(0); auto Y = at::empty({M, N}, X.options().dtype(at::kBFloat16)); return Y; } diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index 5e36ca401d..5a30977b1a 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -470,7 +470,7 @@ def test_quantize_int4_fp8_matmul( w_zp = w_zp.contiguous().to(device="cuda") # Preshuffled i4 weight format. - wq_shuffled, w_scale_row, w_scale_group = quantize_int4_preshuffle(w, 128) + wq_shuffled, (w_scale_group, w_scale_row) = quantize_int4_preshuffle(w, 128) if CudaGraph: g = torch.cuda.CUDAGraph()