Add a differentiable linear module for HYB codebook update #20
Pull request for issue (#19).
Contents
I implemented a new linear module, `bitshift_linear_kernel`. When the `use_prev_kernel` flag is set to `False`, the `QuantizedLinear` module routes its forward pass through the new module. I also modified the fine-tuning code to use the new module when the `--ft_train_lut` flag is set to `True`. Please review the recent commit for details.
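For reviewers who want the gist without checking out the branch: the idea is to reconstruct the weights by indexing a trainable LUT, so that autograd pushes gradients into the codebook entries themselves. Below is a minimal sketch of that pattern. Apart from the names already mentioned above, everything here (class name, tensor layout, decoding) is an illustrative assumption; the actual kernel additionally has to decode the bitshift trellis codes, which this sketch omits.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DifferentiableCodebookLinear(nn.Module):
    """Illustrative sketch: weights are gathered from a trainable codebook
    (LUT), so the codebook entries receive gradients through the gather."""

    def __init__(self, indices: torch.Tensor, codebook_init: torch.Tensor,
                 out_features: int, in_features: int):
        super().__init__()
        # Frozen integer codes selecting a codebook entry per weight.
        self.register_buffer("indices", indices)             # (out*in,) int64
        # Trainable LUT; the point of the new kernel is that these update.
        self.codebook = nn.Parameter(codebook_init.clone())  # (K,)
        self.out_features = out_features
        self.in_features = in_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Indexing an nn.Parameter is a differentiable gather, so
        # d(loss)/d(codebook[k]) accumulates over every weight position
        # whose code is k.
        w = self.codebook[self.indices].view(self.out_features,
                                             self.in_features)
        return F.linear(x, w)
```

The essential property is that the gather makes the codebook a first-class trainable parameter, which is presumably what enables end-to-end LUT fine-tuning here.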
Test

I ran a short script as a minimal unit test and checked its output.
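The exact script and its output are in the commit; as a sketch of the kind of check such a test would perform (forward agreement with a dense reference, plus gradient flow into the codebook), here is a hypothetical version built on the illustrative module above rather than on `bitshift_linear_kernel` itself:

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: K codebook entries, an 8x4 weight matrix.
K, out_f, in_f = 16, 8, 4
indices = torch.randint(K, (out_f * in_f,))
codebook = torch.randn(K)
mod = DifferentiableCodebookLinear(indices, codebook, out_f, in_f)

x = torch.randn(2, in_f)
y = mod(x)

# Forward pass matches a dense linear built from the same LUT.
w_ref = codebook[indices].view(out_f, in_f)
assert torch.allclose(y, x @ w_ref.T, atol=1e-6)

# Gradients actually reach the codebook -- the point of the new kernel.
y.sum().backward()
assert mod.codebook.grad is not None
assert mod.codebook.grad.abs().sum() > 0
print("ok")
```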
Results
For the previous implementation (which does not update the codebook), I could reproduce the QTIP perplexities reported in the paper by running the example script in README.md.
After introducing the updated module, I ran the same script (i.e., the same training recipe) and achieved similar performance, as summarized in the table below. Since the updated implementation trains a different set of parameters than the previous version, some further effort to find the optimal training recipe and devset size is to be expected; I will defer those optimizations to other contributors.
```
-------- Llama-2-7b-hf (2bit, HYB, fine-tune) ----------
                      wiki    c4
qtip (paper)          5.86    7.73
qtip (reprod)         5.88    7.73
qtip (new kernel)     5.92    7.72
```