Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/algo/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def quantize_finetune_decoder_layer(mixed_layer, quant_order, idx, cb, args,
args.tlut_bits,
args.decode_mode,
mode='train-recons' if args.ft_train_lut else 'train-fixW',
use_prev_kernel=not args.ft_train_lut,
dtype=orig_dtype,
grad_ckpt=args.ft_grad_ckpt)
q_linear.trellis.copy_(packed)
Expand Down
14 changes: 10 additions & 4 deletions lib/codebook/bitshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from lib.codebook import kdict
from lib.utils.kernel_check import has_kernel
from lib.utils.kernel_decompress import decode_compressed
from lib.utils.kernel_decompress import decode_compressed, bitshift_linear_kernel
from lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda


Expand Down Expand Up @@ -424,6 +424,7 @@ def forward(self,
rcp,
tp_rank,
mode='eval',
use_prev_kernel=False,
**kwargs):
n, m = len(SU), len(SV)
x = input.view(-1, n).to(torch.float32)
Expand Down Expand Up @@ -452,9 +453,14 @@ def forward(self,
self.cb.recons_lut()

if self.has_kernel:
x = BitshiftLinearKernelAG.apply(
x, trellis, m, n, self.cb.L, self.cb.tlut_bits, self.cb.K,
self.V, self.cb.lut).float()
if use_prev_kernel:
x = BitshiftLinearKernelAG.apply(
x, trellis, m, n, self.cb.L, self.cb.tlut_bits, self.cb.K,
self.V, self.cb.lut).float()
else:
x = bitshift_linear_kernel(
x, trellis, m, n, self.cb.L, self.cb.tlut_bits, self.cb.K,
self.V, self.cb.lut).float()
else:
if mode == 'eval':
trellis = self.cb.unpack_trellis(
Expand Down
5 changes: 4 additions & 1 deletion lib/linear/quantized_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
bias=False,
dtype=torch.float16,
mode='eval',
use_prev_kernel=True,
grad_ckpt=False,
):
super().__init__()
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
self.K_left = K_left
self.K_right = K_right
self.mode = mode
self.use_prev_kernel = use_prev_kernel
self.grad_ckpt = grad_ckpt
self.has_kernel = has_kernel(decode_mode, L, K, V, tlut_bits, td_x,
td_y)
Expand Down Expand Up @@ -145,7 +147,8 @@ def no_ckpt_forward(self, input):
self.K_right,
self.rcp,
self.tp_rank,
mode=self.mode) + 0
mode=self.mode,
use_prev_kernel=self.use_prev_kernel) + 0
if self.bias is not None:
return result + self.bias
return result
50 changes: 42 additions & 8 deletions lib/utils/kernel_decompress.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
import torch

import math

@torch.compile
def decode_compressed(L, S, R, V, m, k, compressed, expanded_lut):
    """Decode a packed buffer into an ``(m, k)`` tensor of LUT values.

    ``decode_indices`` turns ``compressed`` into lookup indices, those
    indices are gathered from ``expanded_lut``, and the gathered values are
    rearranged out of the m16n8k16 mma tile ordering into ``(m, k)`` layout.

    ``S`` is part of the signature but is not read by this function.
    """
    lut_indices = decode_indices(L, R, V, m, k, compressed)

    # Gather the LUT entry for every decoded index.
    swizzled = expanded_lut[lut_indices]

    # Undo the m16n8k16 mma swizzle: view the data as 16x16 tiles, split each
    # tile into its (8, 4, 2, 2, 2) sub-axes, reorder them, and flatten.
    row_tiles, col_tiles = m // 16, k // 16
    tiles = swizzled.reshape(row_tiles, col_tiles, 16, 16)
    split_tiles = tiles.reshape(row_tiles, col_tiles, 8, 4, 2, 2, 2)
    reordered = split_tiles.permute(0, -2, 2, 1, -3, 3, -1)
    return reordered.reshape(m, k)

@torch.compile
def decode_indices(L, R, V, m, k, compressed):
if compressed.dtype != torch.uint16:
compressed = compressed.view(torch.uint16)

Expand Down Expand Up @@ -44,12 +57,33 @@ def decode_compressed(L, S, R, V, m, k, compressed, expanded_lut):
shifted = expanded32 >> (16 - shifts)
indices = torch.bitwise_and(
shifted.reshape(shifted.shape[0], -1)[:, 16 - L::R << V], (1 << L) - 1)

return indices

# decode lut
mma_swizzled = expanded_lut[indices]

# deswizzle m16n8k16 mma pattern
decompressed = (mma_swizzled.reshape(m // 16, k // 16, 16, 16).reshape(
m // 16, k // 16, 8, 4, 2, 2, 2).permute(0, -2, 2, 1, -3, 3,
-1).reshape(m, k))
return decompressed
# for ensuring non-differentiability
class NonDiffDecode(torch.autograd.Function):
    """Autograd function wrapping ``decode_indices`` with a null backward.

    ``backward`` returns ``None`` for every input, so nothing is propagated
    through the index decoding step.
    """

    @staticmethod
    def forward(ctx, trellis, m, n, L, K, V):
        """Decode the flattened ``trellis`` into indices.

        ``V`` is converted to ``int(log2(V))`` before being handed to
        ``decode_indices``; the remaining arguments are passed through.
        """
        # Keep the decode parameters on the autograd context.
        ctx.L, ctx.K, ctx.V = L, K, V
        ctx.m, ctx.n = m, n

        log2_v = int(math.log2(V))
        flat_trellis = trellis.view(-1)
        return decode_indices(L, K, log2_v, m, n, flat_trellis)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``None`` for each of the six ``forward`` inputs."""
        return (None,) * 6

def bitshift_linear_kernel(input, trellis, m, n, L, tlut_bits, K, V, lut):
    """Compute ``input @ hatW.T`` with ``hatW`` rebuilt from ``trellis`` and ``lut``.

    Indices are decoded from ``trellis`` via ``NonDiffDecode`` (whose backward
    returns ``None`` for every input), gathered from ``lut.T``, and rearranged
    from 16x16 tile ordering into an ``(m, n)`` matrix ``hatW``. ``input`` is
    cast to ``hatW``'s dtype before the matmul.

    ``tlut_bits`` is part of the signature but is not read by this function.
    """
    lut_indices = NonDiffDecode.apply(trellis, m, n, L, K, V)
    swizzled = lut.T[lut_indices]

    # Same tile rearrangement as the decode path: 16x16 tiles, split into
    # (8, 4, 2, 2, 2) sub-axes, reordered, then flattened to (m, n).
    row_tiles, col_tiles = m // 16, n // 16
    tiles = swizzled.reshape(row_tiles, col_tiles, 16, 16)
    split_tiles = tiles.reshape(row_tiles, col_tiles, 8, 4, 2, 2, 2)
    hatW = split_tiles.permute(0, -2, 2, 1, -3, 3, -1).reshape(m, n)

    cast_input = input.to(hatW.dtype)
    return cast_input @ hatW.T
1 change: 1 addition & 0 deletions quantize_llama/finetune_e2e_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def main(args):
if module.tlut is not None and args.ft_train_lut:
module.tlut.requires_grad = True
if args.ft_train_lut:
module.use_prev_kernel = False
module.mode = 'train-recons'
glog.info('overriding ft_prefetch_trellis')
elif args.ft_prefetch_trellis:
Expand Down