Conversation

badeok0716

Pull request for issue (#19).

Contents

I implemented a new linear module, bitshift_linear_kernel. When the use_prev_kernel flag is set to False, the QuantizedLinear module uses the new module for the forward pass. I also modified the fine-tuning code to use the new module when the --ft_train_lut flag is set to True. Please review the recent commit for details.
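
For reference, switching a loaded QuantizedLinear layer to the new path only requires flipping the flag and marking the LUT as trainable. The snippet below is a minimal usage sketch, not code from this PR; it assumes a quantized model loaded as in the test script further down and reuses the attribute names use_prev_kernel and tlut from that script.

# Hedged usage sketch: route the forward pass of one layer through
# bitshift_linear_kernel and make its codebook (tlut) trainable,
# mirroring what the --ft_train_lut flag enables in the fine-tuning code.
layer = quant_model.model.layers[0].self_attn.q_proj
layer.use_prev_kernel = False       # use the new kernel for the forward pass
layer.tlut.requires_grad = True     # allow gradients to reach the codebook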

Test

I ran the following script as a minimal unit test:

from lib.utils.unsafe_import import model_from_hf_path
import torch

def get_quantized_layer(path2qmodel="/path2hf/2_7b_2bit"): 
    # load a quantized model
    quant_model = model_from_hf_path(path2qmodel)[0].float()
    # select an arbitrary layer.
    quantized_layer = quant_model.model.layers[0].self_attn.q_proj

    # replicate the routine in finetune_e2e_llama.py #L95:107 with the --ft_train_lut flag
    quantized_layer.SU = torch.nn.Parameter(quantized_layer.SU.float(), requires_grad=True)
    quantized_layer.SV = torch.nn.Parameter(quantized_layer.SV.float(), requires_grad=True)
    quantized_layer.mode = "train-recons"
    quantized_layer.tlut.requires_grad = True

    return quantized_layer

def reinit_grad(quantized_layer):
    quantized_layer.SU.grad = None
    quantized_layer.SV.grad = None
    quantized_layer.tlut.grad = None

def test_backward():
    # load quantized layer
    quantized_layer = get_quantized_layer()

    # initialize random input to the layer
    ft_bs, ctx_size, in_features = 4, 4096, 4096
    input = torch.randn(ft_bs, ctx_size, in_features).to('cuda').to(torch.float16)
    input.requires_grad = True

    # prev forward pass
    quantized_layer.use_prev_kernel = True # flag for previous forward pass
    output = quantized_layer(input)

    # backward pass
    print("=== Prev kernel ===")
    loss = output.sum()
    loss.backward()

    print("loss", loss.item())
    print("input grad", input.grad.shape)
    print("SU grad", quantized_layer.SU.grad.shape)
    print("SV grad", quantized_layer.SV.grad.shape)
    print("tlut grad", quantized_layer.tlut.grad)

    prev_input_grad, prev_SU_grad, prev_SV_grad = input.grad.clone(), quantized_layer.SU.grad.clone(), quantized_layer.SV.grad.clone()
    prev_loss = loss.clone()
    reinit_grad(quantized_layer)
    input.grad = None

    # new forward pass
    quantized_layer.use_prev_kernel = False # flag for new forward pass
    output = quantized_layer(input)

    # backward pass
    print("=== New kernel ===")
    loss = output.sum()
    loss.backward()

    print("loss", loss.item())
    if abs(prev_loss.item() - loss.item()) <= 1e-8:
        print("\tEquivalent loss")
    print("input grad", input.grad.shape)
    if torch.allclose(prev_input_grad, input.grad):
        print("\tEquivalent input grad")
    print("SU grad", quantized_layer.SU.grad.shape)
    if torch.allclose(prev_SU_grad, quantized_layer.SU.grad):
        print("\tEquivalent SU grad")
    print("SV grad", quantized_layer.SV.grad.shape)
    if torch.allclose(prev_SV_grad, quantized_layer.SV.grad):
        print("\tEquivalent SV grad")
    print("tlut grad", quantized_layer.tlut.grad)

if __name__ == "__main__":
    test_backward()

which outputs

=== Prev kernel ===
loss -7860.0
input grad torch.Size([4, 4096, 4096])
SU grad torch.Size([4096])
SV grad torch.Size([4096])
tlut grad None
=== New kernel ===
loss -7860.0
        Equivalent loss
input grad torch.Size([4, 4096, 4096])
        Equivalent input grad
SU grad torch.Size([4096])
        Equivalent SU grad
SV grad torch.Size([4096])
        Equivalent SV grad
tlut grad tensor([[ 407.4847,    4.3954],
        [ 493.1924,   49.0028],
        [-140.4716, -250.8964],
        ...,
        [ 437.4753,  151.0592],
        [ 210.5741, -317.1626],
        [ 248.4033,  -28.8121]], device='cuda:0')
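
As a possible follow-up (not part of this PR), the elementwise comparisons could report explicit tolerances and the maximum deviation instead of relying on the torch.allclose defaults, since the forward pass runs in float16. A small helper along these lines (assert_close is a hypothetical name, not from the repository):

import torch

def assert_close(name, ref, new, rtol=1e-3, atol=1e-5):
    # Print the largest absolute discrepancy together with the allclose verdict,
    # using tolerances loose enough for float16 accumulation differences.
    max_err = (ref - new).abs().max().item()
    ok = torch.allclose(ref, new, rtol=rtol, atol=atol)
    print(f"{name}: max abs diff {max_err:.3e}, allclose={ok}")

# e.g. inside test_backward():
# assert_close("SU grad", prev_SU_grad, quantized_layer.SU.grad)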

Results

For the previous implementation (which does not update the codebook), I could reproduce the perplexity of QTIP reported in the paper by running the example script in README.md.
After introducing the updated module, I ran the same script (i.e., the same training recipe) and obtained similar performance, as summarized in the table below. Since the updated implementation trains a different set of parameters than the previous version, further tuning of the training recipe and devset size may be needed to find the optimal setup; I will defer these optimizations to other contributors.

-------- Llama-2-7b-hf (2bit, HYB, fine-tune) ----------
                      wiki    c4
qtip (paper)          5.86    7.73
qtip (reprod)         5.88    7.73
qtip (new kernel)     5.92    7.72
