This repository contains the implementation of the NeurIPS 2023 paper:
Parameter and Computation Efficient Transfer Learning for Vision-Language Pre-trained Models
Qiong Wu (1,2), Wei Yu (1,2), Yiyi Zhou (1,2), Shubin Huang (1), Xiaoshuai Sun (1,2), Rongrong Ji (1,2)
1 Media Analytics and Computing Lab, Department of Artificial Intelligence, School of Informatics, Xiamen University
2 Institute of Artificial Intelligence, Xiamen University
In this paper, we target parameter and computation efficient transfer learning (PCETL) for VLP (vision-language pre-trained) models. PCETL needs not only to limit the number of trainable parameters in VLP models but also to reduce their computational redundancy during inference, thus enabling a more efficient transfer. To this end, we propose a novel dynamic architecture skipping (DAS) approach towards effective PCETL. DAS first assesses the significance of each module of the VLP model to the downstream task via a reinforcement learning (RL) based process, and then, according to the obtained rewards, skips the redundant modules and replaces them with lightweight networks, i.e., adapters.
[24/10/27] Support for LLaVA is released.
cd LLaVA-DAS
conda create -n llava python=3.10 -y
conda activate llava
pip install --upgrade pip # enable PEP 660 support
pip install -e .
pip install -e ".[train]"
pip install flash-attn --no-build-isolation
The json files can be found at:
./LLaVA-DAS/json_data
The images can be downloaded from:
Slake:
https://www.med-vqa.com/slake/
AID:
https://captain-whu.github.io/AID/
pip install -r requirements.txt
pip install -e .
We follow ViLT and use pyarrow to serialize the datasets. See the ViLT repository for details.
cd LaVIN-DAS
conda create -n lavin python=3.8 -y
conda activate lavin
# install pytorch
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 -c pytorch
# install dependency and lavin
pip install -r requirements.txt
pip install -e .
Obtain the LLaMA weights through the official request form, or download LLaMA-7B directly.
For ScienceQA, please prepare the dataset from the official repo.
For BoolQ, CommonSenseQA and gsm8k, please run:
pip install datasets
python OrgBoolQ.py
python OrgCommonSenseQA.py
python OrgGSM8K.py
The file structure should look like:
LaVIN-DAS/
  |-- das
  |-- scripts
  |-- train.py
  |-- eval.py
  ......
data/
  |-- problem.json
  |-- pid_splits.json
  |-- captions.json
  |-- all_data.json
  |-- images
      |-- train        # ScienceQA train images
      |-- val          # ScienceQA val images
      |-- test         # ScienceQA test images
  |-- weights
      |-- tokenizer.model
      |-- 7B
          |-- params.json
          |-- consolidated.00.pth
      ......
  |-- BoolQ
      |-- boolq_0_shot_test.json
  |-- GSM8K
      |-- gsm8k_0_shot_test.json
  |-- CommonSenseQA
      |-- commonsense_qa_0_shot_test.json
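Before launching any script, it can help to confirm that the data directory matches the layout above. The sketch below is a hypothetical helper (`check_lavin_data` is not part of the repository); it assumes the data root is passed as the first argument (default `./data`) and that only the 7B weights are needed.

```shell
#!/usr/bin/env bash
# Hypothetical helper: verify that a LaVIN-DAS data root matches the layout above.
# Assumptions: data root given as $1 (default ./data); only 7B weights required.
check_lavin_data() {
  local root="${1:-./data}"
  local missing=0
  local f
  # Files that must be present
  for f in problem.json pid_splits.json captions.json all_data.json \
           weights/tokenizer.model weights/7B/params.json \
           weights/7B/consolidated.00.pth \
           BoolQ/boolq_0_shot_test.json \
           GSM8K/gsm8k_0_shot_test.json \
           CommonSenseQA/commonsense_qa_0_shot_test.json; do
    if [ ! -f "$root/$f" ]; then
      echo "missing file: $root/$f" >&2
      missing=$((missing + 1))
    fi
  done
  # Directories holding the ScienceQA images
  for f in images/train images/val images/test; do
    if [ ! -d "$root/$f" ]; then
      echo "missing directory: $root/$f" >&2
      missing=$((missing + 1))
    fi
  done
  # Exit status 0 only if nothing is missing
  [ "$missing" -eq 0 ]
}
```

Usage: `check_lavin_data ./data && echo "layout OK"`. Each missing path is reported on stderr, so the whole list is visible in one run instead of failing on the first gap.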
To search and fine-tune LLaVA, run:
cd LLaVA-DAS
sh scripts/{task_type}_DAS_{benchmark}.sh
The task_type can be {search} or {finetune}.
The benchmark can be {aid20} (AID), {aid50} (AID), or {slake} (Slake).
For evaluation, run:
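The `{task_type}` and `{benchmark}` placeholders expand to six script names. The sketch below uses a hypothetical helper, `llava_das_script`, to build a path while rejecting values outside the lists above. It then prints, without executing, the commands in the order search before finetune. This ordering reflects that the searched structure must be pasted into the code before fine-tuning.

```shell
#!/usr/bin/env bash
# Hypothetical helper: build a LLaVA-DAS script path, rejecting unknown values.
llava_das_script() {
  local task="$1" bench="$2"
  case "$task" in
    search|finetune) ;;
    *) echo "unknown task_type: $task" >&2; return 1 ;;
  esac
  case "$bench" in
    aid20|aid50|slake) ;;
    *) echo "unknown benchmark: $bench" >&2; return 1 ;;
  esac
  echo "scripts/${task}_DAS_${bench}.sh"
}

# Dry run: only print the commands. Between search and finetune, the searched
# structure still has to be filled in by hand.
for bench in aid20 aid50 slake; do
  for task in search finetune; do
    echo "sh $(llava_das_script "$task" "$bench")"
  done
done
```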
sh scripts/v1_5/eval/{benchmark}.sh
Note: the searched structure is printed to the shell and must be filled in manually at the following locations:
Finetuning: llava/train/train.py line 1011
Evaluating: llava/model/builder.py line 196
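To locate the two places quickly, you can print a few numbered lines around each one. The sketch below defines a hypothetical helper, `show_context`, that marks the target line with `>`. The line numbers 1011 and 196 come from the note above and may drift if the files are edited.

```shell
#!/usr/bin/env bash
# Hypothetical helper: print lines [line - radius, line + radius] of a file,
# numbered, with the target line marked by ">".
show_context() {
  local file="$1" line="$2" radius="${3:-3}"
  awk -v l="$line" -v r="$radius" \
    'NR >= l - r && NR <= l + r { printf "%s%d: %s\n", (NR == l ? ">" : " "), NR, $0 }' \
    "$file"
}

# Usage, from inside LLaVA-DAS:
#   show_context llava/train/train.py 1011
#   show_context llava/model/builder.py 196
```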
To work with METER:
cd METER
To work with ViLT:
cd ViLT
VQAv2:
sh script/vqa_search.sh
Add the search result to vqa_train.sh as the additional parameter 'skip_module'.
sh script/vqa_train.sh
Add the checkpoint path and 'skip_module' to vqa_eval.sh.
sh script/vqa_eval.sh
Flickr30K:
sh script/F30K_search.sh
Add the search result to F30K_train.sh as the additional parameter 'skip_module'.
sh script/F30K_train.sh
Add the checkpoint path and 'skip_module' to F30K_eval.sh.
sh script/F30K_eval.sh
NLVR2:
sh script/nlvr_search.sh
Add the search result to nlvr_train.sh as the additional parameter 'skip_module'.
sh script/nlvr_train.sh
Add the checkpoint path and 'skip_module' to nlvr_eval.sh.
sh script/nlvr_eval.sh
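Each METER/ViLT task follows the same search, train, eval sequence, with a manual edit before train and before eval. The sketch below wraps that sequence in hypothetical helpers (`run_das_task`, `run_step`, and the `DRY_RUN` switch are not part of the repository). It assumes it is run from inside `METER/` or `ViLT/` and that the scripts follow the `script/<task>_{search,train,eval}.sh` names above. By default it only prints the commands. With `DRY_RUN=0` it executes them and waits for Enter at each manual step.

```shell
#!/usr/bin/env bash
# Hypothetical wrapper around the per-task scripts listed above.
DRY_RUN="${DRY_RUN:-1}"   # 1: print commands only; 0: execute them

run_step() {
  if [ "$DRY_RUN" = "1" ]; then
    echo "sh $1"
  else
    sh "$1"
  fi
}

run_das_task() {
  local task="$1"
  case "$task" in
    vqa|F30K|nlvr) ;;
    *) echo "unknown task: $task" >&2; return 1 ;;
  esac
  run_step "script/${task}_search.sh" || return 1
  echo "Edit script/${task}_train.sh: pass the search result as 'skip_module'."
  [ "$DRY_RUN" = "1" ] || read -r -p "Press Enter when done... "
  run_step "script/${task}_train.sh" || return 1
  echo "Edit script/${task}_eval.sh: add the checkpoint path and 'skip_module'."
  [ "$DRY_RUN" = "1" ] || read -r -p "Press Enter when done... "
  run_step "script/${task}_eval.sh"
}
```

Usage: after sourcing the file, `run_das_task nlvr` lists the steps. To run them for real, set `DRY_RUN=0` first. Stopping on the first failing step avoids training with a stale `skip_module` value.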
We also evaluate DAS on ScienceQA, following LaVIN.
Table 1: Comparison of DAS and PETL methods on ScienceQA for LLaMA.
Method | Updated Params | Inference Time | Modality Natural | Modality Social | Modality Language | Context Text | Context Image | Context No | Grade G1-6 | Grade G7-12 | Avg |
---|---|---|---|---|---|---|---|---|---|---|---|
LaVIN-7B | 3.8M | 3.70s | 89.25 | 94.94 | 85.24 | 88.51 | 87.46 | 88.08 | 90.16 | 88.07 | 89.41 |
DAS2-7B | 4.2M | 3.44s | 88.68 | 94.94 | 86.45 | 88.03 | 86.81 | 88.92 | 90.20 | 88.00 | 89.41 |
DAS4-7B | 4.6M | 3.23s | 88.99 | 94.60 | 85.09 | 87.88 | 86.51 | 88.36 | 89.72 | 88.13 | 89.15 |
DAS6-7B | 5.0M | 3.06s | 87.30 | 93.36 | 82.36 | 86.12 | 85.97 | 85.71 | 88.18 | 85.70 | 87.29 |
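As a quick reading of Table 1, the relative inference-time reduction of each DAS variant with respect to LaVIN-7B follows from the Inference Time column:

$$
\begin{aligned}
r &= \frac{t_{\text{LaVIN}} - t_{\text{DAS}}}{t_{\text{LaVIN}}},\\
r_{\text{DAS2}} &= \frac{3.70 - 3.44}{3.70} \approx 7.0\%,\\
r_{\text{DAS4}} &= \frac{3.70 - 3.23}{3.70} \approx 12.7\%,\\
r_{\text{DAS6}} &= \frac{3.70 - 3.06}{3.70} \approx 17.3\%.
\end{aligned}
$$

The corresponding changes in average accuracy are 0.00, -0.26 and -2.12 points. DAS2-7B therefore matches LaVIN-7B's average (89.41) while reducing inference time by about 7%, at the cost of 0.4M more updated parameters.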
To search and fine-tune LLaMA, run:
cd LaVIN-DAS
sh scripts/{task_type}_{benchmark}_7b.sh
The task_type includes {search}, {finetune} and {evaluate}.
The benchmark includes {boolq} (BoolQ), {csqa} (CommonSenseQA), and {gsm8k} (GSM8K).
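The sketch below uses a hypothetical helper, `lavin_das_commands`, to list the LaVIN-DAS commands for one or more benchmarks in the order search, finetune, evaluate. It assumes each stage depends on the output of the previous one. It rejects benchmarks outside the list above.

```shell
#!/usr/bin/env bash
# Hypothetical helper: print LaVIN-DAS commands in pipeline order per benchmark.
lavin_das_commands() {
  local bench task
  for bench in "$@"; do
    case "$bench" in
      boolq|csqa|gsm8k) ;;
      *) echo "unknown benchmark: $bench" >&2; return 1 ;;
    esac
    for task in search finetune evaluate; do
      echo "sh scripts/${task}_${bench}_7b.sh"
    done
  done
}
```

Usage: `lavin_das_commands boolq > run_boolq.sh && sh -e run_boolq.sh`. Writing the commands to a file, rather than piping them into `sh`, keeps the training scripts' stdin detached from the command list. The `-e` flag stops the run at the first failing stage.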
The code is based on ViLT (licensed under Apache 2.0) and METER (licensed under MIT), and some of the code is borrowed from CLIP and Swin-Transformer.