From 1f7eaa3cd2726cfbdea796125c835788b13b7d67 Mon Sep 17 00:00:00 2001
From: badeok0716
Date: Thu, 9 Jan 2025 14:39:36 +0900
Subject: [PATCH] add differentiable kernel

---
 lib/algo/finetune.py                 |  1 +
 lib/codebook/bitshift.py             | 14 +++++---
 lib/linear/quantized_linear.py       |  5 ++-
 lib/utils/kernel_decompress.py       | 50 +++++++++++++++++++++++-----
 quantize_llama/finetune_e2e_llama.py |  1 +
 5 files changed, 58 insertions(+), 13 deletions(-)

diff --git a/lib/algo/finetune.py b/lib/algo/finetune.py
index 5847eb3..41ef496 100644
--- a/lib/algo/finetune.py
+++ b/lib/algo/finetune.py
@@ -271,6 +271,7 @@ def quantize_finetune_decoder_layer(mixed_layer, quant_order, idx, cb, args,
                              args.tlut_bits,
                              args.decode_mode,
                              mode='train-recons' if args.ft_train_lut else 'train-fixW',
+                             use_prev_kernel=not args.ft_train_lut,
                              dtype=orig_dtype,
                              grad_ckpt=args.ft_grad_ckpt)
     q_linear.trellis.copy_(packed)
diff --git a/lib/codebook/bitshift.py b/lib/codebook/bitshift.py
index c20f225..e09a4a9 100644
--- a/lib/codebook/bitshift.py
+++ b/lib/codebook/bitshift.py
@@ -10,7 +10,7 @@
 
 from lib.codebook import kdict
 from lib.utils.kernel_check import has_kernel
-from lib.utils.kernel_decompress import decode_compressed
+from lib.utils.kernel_decompress import decode_compressed, bitshift_linear_kernel
 from lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda
 
@@ -424,6 +424,7 @@ def forward(self,
                 rcp,
                 tp_rank,
                 mode='eval',
+                use_prev_kernel=False,
                 **kwargs):
         n, m = len(SU), len(SV)
         x = input.view(-1, n).to(torch.float32)
@@ -452,9 +453,14 @@ def forward(self,
             self.cb.recons_lut()
 
         if self.has_kernel:
-            x = BitshiftLinearKernelAG.apply(
-                x, trellis, m, n, self.cb.L, self.cb.tlut_bits, self.cb.K,
-                self.V, self.cb.lut).float()
+            if use_prev_kernel:
+                x = BitshiftLinearKernelAG.apply(
+                    x, trellis, m, n, self.cb.L, self.cb.tlut_bits, self.cb.K,
+                    self.V, self.cb.lut).float()
+            else:
+                x = bitshift_linear_kernel(
+                    x, trellis, m, n, self.cb.L, self.cb.tlut_bits, self.cb.K,
+                    self.V, self.cb.lut).float()
         else:
             if mode == 'eval':
                 trellis = self.cb.unpack_trellis(
diff --git a/lib/linear/quantized_linear.py b/lib/linear/quantized_linear.py
index a917f4b..19069f6 100644
--- a/lib/linear/quantized_linear.py
+++ b/lib/linear/quantized_linear.py
@@ -25,6 +25,7 @@ def __init__(
         bias=False,
         dtype=torch.float16,
         mode='eval',
+        use_prev_kernel=True,
         grad_ckpt=False,
     ):
         super().__init__()
@@ -76,6 +77,7 @@
         self.K_left = K_left
         self.K_right = K_right
         self.mode = mode
+        self.use_prev_kernel = use_prev_kernel
         self.grad_ckpt = grad_ckpt
         self.has_kernel = has_kernel(decode_mode, L, K, V, tlut_bits,
                                      td_x, td_y)
@@ -145,7 +147,8 @@ def no_ckpt_forward(self, input):
             self.K_right,
             self.rcp,
             self.tp_rank,
-            mode=self.mode) + 0
+            mode=self.mode,
+            use_prev_kernel=self.use_prev_kernel) + 0
         if self.bias is not None:
             return result + self.bias
         return result
diff --git a/lib/utils/kernel_decompress.py b/lib/utils/kernel_decompress.py
index cece4da..76564c6 100644
--- a/lib/utils/kernel_decompress.py
+++ b/lib/utils/kernel_decompress.py
@@ -1,8 +1,21 @@
 import torch
-
+import math
 
 @torch.compile
 def decode_compressed(L, S, R, V, m, k, compressed, expanded_lut):
+    indices = decode_indices(L, R, V, m, k, compressed)
+
+    # decode lut
+    mma_swizzled = expanded_lut[indices]
+
+    # deswizzle m16n8k16 mma pattern
+    decompressed = (mma_swizzled.reshape(m // 16, k // 16, 16, 16).reshape(
+        m // 16, k // 16, 8, 4, 2, 2, 2).permute(0, -2, 2, 1, -3, 3,
+                                                 -1).reshape(m, k))
+    return decompressed
+
+
+@torch.compile
+def decode_indices(L, R, V, m, k,
+                   compressed):
     if compressed.dtype != torch.uint16:
         compressed = compressed.view(torch.uint16)
@@ -44,12 +57,33 @@ def decode_compressed(L, S, R, V, m, k, compressed, expanded_lut):
     shifted = expanded32 >> (16 - shifts)
     indices = torch.bitwise_and(
         shifted.reshape(shifted.shape[0], -1)[:, 16 - L::R << V],
         (1 << L) - 1)
+
+    return indices
 
-    # decode lut
-    mma_swizzled = expanded_lut[indices]
-
-    # deswizzle m16n8k16 mma pattern
-    decompressed = (mma_swizzled.reshape(m // 16, k // 16, 16, 16).reshape(
-        m // 16, k // 16, 8, 4, 2, 2, 2).permute(0, -2, 2, 1, -3, 3,
-                                                 -1).reshape(m, k))
-    return decompressed
+
+# for ensuring non-differentiability
+class NonDiffDecode(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, trellis, m, n, L, K, V):
+        ctx.L = L
+        ctx.K = K
+        ctx.V = V
+        ctx.m = m
+        ctx.n = n
+
+        indices = decode_indices(L, K, int(math.log2(V)),
+                                 m, n, trellis.view(-1))
+        return indices
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return None, None, None, None, None, None
+
+
+def bitshift_linear_kernel(input, trellis, m, n, L, tlut_bits, K, V, lut):
+    indices = NonDiffDecode.apply(trellis, m, n, L, K, V)
+    mma_swizzled = lut.T[indices]
+
+    hatW = (mma_swizzled.reshape(m // 16, n // 16, 16, 16).reshape(
+        m // 16, n // 16, 8, 4, 2, 2, 2).permute(0, -2, 2, 1, -3, 3,
+                                                 -1).reshape(m, n))
+    return input.to(hatW.dtype) @ hatW.T
\ No newline at end of file
diff --git a/quantize_llama/finetune_e2e_llama.py b/quantize_llama/finetune_e2e_llama.py
index 837da98..7cf1674 100644
--- a/quantize_llama/finetune_e2e_llama.py
+++ b/quantize_llama/finetune_e2e_llama.py
@@ -98,6 +98,7 @@ def main(args):
             if module.tlut is not None and args.ft_train_lut:
                 module.tlut.requires_grad = True
             if args.ft_train_lut:
+                module.use_prev_kernel = False
                 module.mode = 'train-recons'
                 glog.info('overriding ft_prefetch_trellis')
             elif args.ft_prefetch_trellis:
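
--
Note on the gradient flow this patch sets up: NonDiffDecode.backward
returns None for every input, so autograd treats the decoded indices
as constants, while the LUT gather in bitshift_linear_kernel runs
outside the autograd.Function and therefore stays differentiable with
respect to the LUT. A minimal standalone sketch of that pattern
follows; the toy shapes are illustrative assumptions, not the
repository's real tile dimensions:

    import torch

    # stand-in trainable codebook (plays the role of self.cb.lut)
    lut = torch.randn(1, 16, requires_grad=True)
    # stand-in for the integer indices NonDiffDecode would return
    idx = torch.randint(0, 16, (8, 8))
    # the gather is ordinary tensor indexing, so it is differentiable in lut
    hatW = lut.T[idx].reshape(8, 8)
    x = torch.randn(4, 8)
    (x @ hatW.T).sum().backward()
    assert lut.grad is not None  # gradients reach the LUT

This is why finetune_e2e_llama.py flips use_prev_kernel to False only
when args.ft_train_lut is set: the new path lets 'train-recons' update
the LUT while the packed trellis stays fixed.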