Preshuffled BF16I4 Gemm Kernel #3913

Closed · wants to merge 1 commit
55 changes: 52 additions & 3 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -1426,7 +1426,7 @@ def cuda(self) -> bool:
class F8I4ShuffledGemm(QuantizeOpBase):
def preprocess(self, x, w):
# Prequantize and pack weights.
wq, row_scale, group_scale = quantize_int4_preshuffle(w)
wq, (group_scale, row_scale) = quantize_int4_preshuffle(w)
return x, wq, row_scale, group_scale

def quantize(self, x, wq, row_scale, group_scale):
@@ -1470,6 +1470,49 @@ def cuda(self) -> bool:
return True


@register_quantize_op
class BF16I4ShuffledGemm(QuantizeOpBase):
def preprocess(self, x, w):
# Prequantize and pack weights.
wq, (group_scale, group_zero) = quantize_int4_preshuffle(w, dtype="bf16")
return x, wq, group_scale, group_zero

def quantize(self, x, wq, group_scale, group_zero):
# No extra action required.
return x, wq, group_scale, group_zero

def compute(self, x, wq, group_scale, group_zero):
# Handle batched cases by looping over each batch.
if x.dim() == 3:
B, M, _ = x.shape
_, N, _ = wq.shape
y = torch.empty((B, M, N), device=x.device, dtype=torch.bfloat16)
for i in range(B):
y[i] = torch.ops.fbgemm.bf16i4bf16_shuffled(
x[i], wq[i], group_scale[i], group_zero[i]
)
return y
# Otherwise run gemm normally.
return torch.ops.fbgemm.bf16i4bf16_shuffled(x, wq, group_scale, group_zero)

def quantize_and_compute(self, x, wq, group_scale, group_zero):
x, wq, group_scale, group_zero = self.quantize(x, wq, group_scale, group_zero)
return self.compute(x, wq, group_scale, group_zero)

@property
def name(self) -> str:
return "cutlass_bf16i4_preshuffle"

@property
def hip(self) -> bool:
# Not yet supported on AMD.
return False

@property
def cuda(self) -> bool:
return True

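Illustrative usage, not part of this patch: a minimal sketch of the new BF16 path end to end, assuming the fbgemm_gpu gen_ai extension is built so that the bf16i4bf16_shuffled op is registered under torch.ops.fbgemm, and that quantize_int4_preshuffle is importable from the gen_ai quantize module changed below (the import path, shapes, and CUDA device are assumptions).

import torch
# Import path assumed; quantize_int4_preshuffle is defined in
# fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py below.
from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle

# Example shapes: [M, K] bf16 activations, [N, K] bf16 weights, K divisible by group_size.
x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(2048, 4096, device="cuda", dtype=torch.bfloat16)

# BF16 activations: the returned scales tuple is (group_scale, group_zero).
wq, (group_scale, group_zero) = quantize_int4_preshuffle(w, dtype="bf16")

# Run the preshuffled BF16 x INT4 gemm added by this PR.
y = torch.ops.fbgemm.bf16i4bf16_shuffled(x, wq, group_scale, group_zero)
print(y.shape)  # torch.Size([16, 2048])

This mirrors BF16I4ShuffledGemm.compute above; batched inputs (x.dim() == 3) are handled there by looping over the leading dimension.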

@register_quantize_op
class F8I4ShuffledGroupedGemm(QuantizeOpBase):
"""
@@ -1485,7 +1528,8 @@ def preprocess(self, x, w):
m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
# Quantize weights.
# TODO Only rowwise scaling is currently supported. This needs to be fixed.
wq, row_scale, group_scale = zip(*[quantize_int4_preshuffle(i) for i in w])
wq, scales = zip(*[quantize_int4_preshuffle(i) for i in w])
group_scale, row_scale = zip(*scales)
# Group weights as single tensor.
wq = torch.stack(wq, dim=0).contiguous()
row_scale = torch.stack(row_scale, dim=0).contiguous()
@@ -1580,7 +1624,12 @@ def quantize(self, x, w):
wq, w_scale, w_zp = self._int4_row_quantize(w)
# Pack int4 values together.
wq = self._pack_int4(wq)
return x.to(torch.bfloat16), wq, w_scale, w_zp
return (
x.to(torch.bfloat16),
wq,
w_scale,
w_zp,
)

def compute(self, x, wq, w_scale, w_zp):
return torch.ops.fbgemm.bf16i4bf16_rowwise(x, wq, w_scale, w_zp)
108 changes: 73 additions & 35 deletions fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py
@@ -29,6 +29,34 @@ def pack_int4(x: torch.Tensor) -> torch.Tensor:
return torch.bitwise_or(low_x, high_x).contiguous()


def int4_row_quantize_zp(
x: torch.Tensor,
group_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
n_bit = 4 # Number of target bits.
to_quant = x.reshape(-1, group_size).to(torch.float)

max_val = to_quant.amax(dim=1, keepdim=True)
min_val = to_quant.amin(dim=1, keepdim=True)
max_int = 2**n_bit - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-6) / max_int

zeros = min_val + scales * (2 ** (n_bit - 1))

out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)

# Recenter output and move to int8.
out = (out - 2 ** (n_bit - 1)).to(dtype=torch.int8).reshape(x.shape)

# Cutlass expects column major layout for scale and zero point,
# so we transpose here and make them contiguous.
scales = scales.view(x.shape[0], -1).t().contiguous()
zeros = zeros.view(x.shape[0], -1).t().contiguous()

return out, scales, zeros

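For reference, a small self-contained sketch (not part of the patch) of the affine scheme int4_row_quantize_zp implements, given the helper defined above: each recentered int4 code q reconstructs its input as approximately q * scale + zero, with scale and zero returned column major as [K // group_size, N]. Shapes below are arbitrary.

import torch

N, K, G = 4, 256, 128
x = torch.randn(N, K)
q, scale, zero = int4_row_quantize_zp(x, group_size=G)  # q: int8 codes in [-8, 7]

# Dequantize group-wise: transpose scale/zero back to [N, K // G] and broadcast over each group.
x_hat = q.float().reshape(N, K // G, G) * scale.t().unsqueeze(-1) + zero.t().unsqueeze(-1)

# Rounding error is bounded by half a scale step per group.
max_err = (x_hat.reshape(N, K) - x).abs().max()
print(bool(max_err <= 0.5 * scale.max()))  # True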

def int4_row_quantize(
x: torch.Tensor,
group_size: int = 128,
@@ -63,8 +91,8 @@ def int4_row_quantize(


def quantize_int4_preshuffle(
w: torch.Tensor, group_size: int = 128
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
w: torch.Tensor, group_size: int = 128, dtype: str = "fp8"
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Quantizes an input weight tensor to int4 using preshuffling and scale packing.
This function is intended to be used with fbgemm's mixed dtype kernels and is expected
@@ -73,47 +101,57 @@
Args:
w (Tensor): [N, K] Higher precision weight tensor to quantize. May optionally have a batch dimension.
group_size (int): Number of elements to calculate group scale for, must be at least 128.
dtype (str): Type of the corresponding activations. Must be "fp8" or "bf16".
Returns:
wq (Tensor): [N, K // 2] Quantized int4 weight tensor packed into int8 elements.
row_scale (Tensor): [N] FP32 Scale per row of the weight tensor.
group_scale (Tensor): [K / group_size, 8, N] FP8 Scale per group of the weight tensor.
scales (Tuple[Tensor]): Scale tensors for the specified activation type. When FP8 is used,
scales is a tuple of group_scale ([K / group_size, 8, N]) and row_scale ([N]). When BF16 is
used, scales is a tuple of group_scale ([K / group_size, N]) and group_zero ([K / group_size, N]).
"""

def _quantize(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Start by lowering weights to FP8 and producing row scales.
wq, row_scale = quantize_fp8_row(w)

# Now reduce to INT4.
wq, group_scale = int4_row_quantize(wq, group_size)
# Reduce group scale to FP8.
group_scale = group_scale.to(torch.float8_e4m3fn)

# Take quantized weights and pack them efficiently.
wq = pack_int4(wq)

# Finally pack weights and scales into efficient preshuffled format.
wq, group_scale = torch.ops.fbgemm.preshuffle_i4(wq, group_scale)

return wq, row_scale, group_scale
def _quantize(
w: torch.Tensor, dtype: str = "fp8"
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:

if dtype == "fp8":
# Start by lowering weights to FP8 and producing row scales.
wq, row_scale = quantize_fp8_row(w)

# Now reduce to INT4.
wq, group_scale = int4_row_quantize(wq, group_size)
# Reduce group scale to FP8.
group_scale = group_scale.to(torch.float8_e4m3fn)
# Take quantized weights and pack them efficiently.
wq = pack_int4(wq)
# Finally pack weights and scales into efficient preshuffled format.
wq, group_scale = torch.ops.fbgemm.preshuffle_i4(wq, group_scale)
return wq, (group_scale, row_scale)

elif dtype == "bf16":
wq, group_scale, group_zero = int4_row_quantize_zp(w, group_size)
# Set scales to activation type.
group_scale = group_scale.to(torch.bfloat16)
group_zero = group_zero.to(torch.bfloat16)
# Take quantized weights and pack them efficiently.
wq = pack_int4(wq)
# Finally pack weights and scales into efficient preshuffled format.
wq, group_scale = torch.ops.fbgemm.preshuffle_i4(wq, group_scale)
return wq, (group_scale, group_zero)
else:
raise NotImplementedError("Only fp8 and bf16 activations supported.")

if w.ndim >= 3:
orig_shape = w.shape
# Flatten to 3 dimensions then iterate over batches.
w = w.view(-1, *w.shape[1:])
w.unbind(dim=0)
wq = []
row_scale = []
group_scale = []
for batch in w:
wq_, row_scale_, group_scale_ = _quantize(batch)
wq.append(wq_)
row_scale.append(row_scale_)
group_scale.append(group_scale_)
wq, scales = zip(*[_quantize(i, dtype=dtype) for i in w])
wq = torch.stack(wq).view(*orig_shape[:-2], *wq[0].shape)
row_scale = torch.stack(row_scale).view(*orig_shape[:-2], *row_scale[0].shape)
group_scale = torch.stack(group_scale).view(
*orig_shape[:-2], *group_scale[0].shape
# Decompose then stack scales back into a tuple.
a_scales, b_scales = zip(*scales)
scales = (
torch.stack(a_scales).view(*orig_shape[:-2], *a_scales[0].shape),
torch.stack(b_scales).view(*orig_shape[:-2], *b_scales[0].shape),
)
else:
wq, row_scale, group_scale = _quantize(w)
return wq, row_scale, group_scale
wq, scales = _quantize(w, dtype=dtype)

return wq, scales
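A short sketch of the new calling convention (not part of the patch; shapes, device, and import path are assumptions): quantize_int4_preshuffle now returns (wq, scales), where scales is (group_scale, row_scale) for the default "fp8" activations and (group_scale, group_zero) for "bf16", and weights with a leading batch dimension are quantized per batch with their scale tensors restacked.

import torch
from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle  # path assumed

w = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)  # [N, K]

# Default FP8 activation path: FP8 group scale plus row-wise scale.
wq, (group_scale, row_scale) = quantize_int4_preshuffle(w)

# BF16 activation path consumed by bf16i4bf16_shuffled: group scale plus group zero point.
wq, (group_scale, group_zero) = quantize_int4_preshuffle(w, dtype="bf16")

# Batched weights are also supported; each [N, K] slice is quantized independently.
w3d = torch.randn(2, 512, 1024, device="cuda", dtype=torch.bfloat16)  # [B, N, K]
wq3d, (gs3d, gz3d) = quantize_int4_preshuffle(w3d, dtype="bf16")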