From ee20413a1415fa763bf6bd618baa679668950a07 Mon Sep 17 00:00:00 2001 From: ankitade Date: Sun, 26 Jun 2022 07:46:47 +0000 Subject: [PATCH] [Flava] Add ckpt loading and accuracy metric to finetuning [ghstack-poisoned] --- examples/flava/configs/finetuning/qnli.yaml | 4 ++- examples/flava/finetune.py | 25 ++++++++++++------ examples/flava/model.py | 28 +++++++++++++++++---- torchmultimodal/models/flava/flava_model.py | 10 +++++++- torchmultimodal/utils/common.py | 3 ++- 5 files changed, 54 insertions(+), 16 deletions(-) diff --git a/examples/flava/configs/finetuning/qnli.yaml b/examples/flava/configs/finetuning/qnli.yaml index dd2659cc..1b1fc1fb 100644 --- a/examples/flava/configs/finetuning/qnli.yaml +++ b/examples/flava/configs/finetuning/qnli.yaml @@ -4,7 +4,7 @@ training: _target_: flava.definitions.TrainingArguments lightning: max_steps: 33112 - gpus: -1 + gpus: 1 progress_bar_refresh_rate: 50 val_check_interval: 1000 num_sanity_val_steps: 0 @@ -16,6 +16,8 @@ training: every_n_train_steps: 1000 save_on_train_epoch_end: true verbose: true + monitor: validation/accuracy/classification + mode: max lightning_load_from_checkpoint: null seed: -1 batch_size: 32 diff --git a/examples/flava/finetune.py b/examples/flava/finetune.py index 8078051b..d753185b 100644 --- a/examples/flava/finetune.py +++ b/examples/flava/finetune.py @@ -8,10 +8,10 @@ from flava.data.datamodules import VLDataModule from flava.definitions import FLAVAArguments from flava.model import FLAVAClassificationLightningModule +from flava.utils import build_config, build_datamodule_kwargs from omegaconf import OmegaConf from pytorch_lightning import seed_everything, Trainer -from pytorch_lightning.callbacks import LearningRateMonitor -from utils import build_config, build_datamodule_kwargs +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint AVAIL_GPUS = 1 SEED = -1 @@ -55,14 +55,23 @@ def main(): **config.model, ) + callbacks = [ + LearningRateMonitor(logging_interval="step"), + ] + + if config.training.lightning_checkpoint is not None: + callbacks.append( + ModelCheckpoint( + **OmegaConf.to_container(config.training.lightning_checkpoint) + ) + ) + trainer = Trainer( - **OmegaConf.to_container(config.training.lightning), - callbacks=[ - LearningRateMonitor(logging_interval="step"), - ], + **OmegaConf.to_container(config.training.lightning), callbacks=callbacks ) - trainer.fit(model, datamodule=datamodule) - trainer.validate(model, datamodule=datamodule) + ckpt_path = config.training.lightning_load_from_checkpoint + trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path) + trainer.validate(datamodule=datamodule) if __name__ == "__main__": diff --git a/examples/flava/model.py b/examples/flava/model.py index 29baad71..e308c3e3 100644 --- a/examples/flava/model.py +++ b/examples/flava/model.py @@ -8,6 +8,7 @@ import torch from pytorch_lightning import LightningModule +from torchmetrics import Accuracy from torchmultimodal.models.flava.flava_model import ( flava_model_for_classification, flava_model_for_pretraining, @@ -139,18 +140,33 @@ def __init__( self.warmup_steps = warmup_steps self.max_steps = max_steps self.adam_betas = adam_betas + self.metrics = Accuracy() def training_step(self, batch, batch_idx): - output = self._step(batch, batch_idx) + output, accuracy = self._step(batch, batch_idx) self.log("train/losses/classification", output.loss, prog_bar=True, logger=True) + self.log( + "train/accuracy/classification", + accuracy, + prog_bar=True, + logger=True, + sync_dist=True, + ) return output.loss def validation_step(self, batch, batch_idx): - output = self._step(batch, batch_idx) + output, accuracy = self._step(batch, batch_idx) self.log( "validation/losses/classification", output.loss, prog_bar=True, logger=True ) + self.log( + "validation/accuracy/classification", + accuracy, + prog_bar=True, + logger=True, + sync_dist=True, + ) return output.loss @@ -164,15 +180,17 @@ def _step(self, batch, batch_idx): else: raise RuntimeError("Batch needs to have either or both 'image' and 'text'.") + labels = batch.get("labels", None) output = self.model( image=batch.get("image", None), text=batch.get("text", None), required_embedding=required_embedding, - labels=batch.get("labels", None), + labels=labels, ) + if labels is not None: + accuracy = self.metrics(output.logits, labels) - # TODO: Add accuracy metric to this later. - return output + return output, accuracy def configure_optimizers(self): return get_optimizers_for_lightning( diff --git a/torchmultimodal/models/flava/flava_model.py b/torchmultimodal/models/flava/flava_model.py index 9c74f527..8d7f0b6d 100644 --- a/torchmultimodal/models/flava/flava_model.py +++ b/torchmultimodal/models/flava/flava_model.py @@ -209,6 +209,7 @@ def flava_model_for_classification( classifier_activation: Callable[..., nn.Module] = nn.ReLU, classifier_normalization: Optional[Callable[..., nn.Module]] = None, loss_fn: Optional[Callable[..., Tensor]] = None, + pretrained_model_key: Optional[str] = "flava_full", **flava_model_kwargs: Any, ): model = flava_model(**flava_model_kwargs) @@ -224,7 +225,14 @@ def flava_model_for_classification( if loss_fn is None: loss_fn = nn.CrossEntropyLoss() - return FLAVAForClassification(model=model, classifier=classifier, loss=loss_fn) + classification_model = FLAVAForClassification( + model=model, classifier=classifier, loss=loss_fn + ) + if pretrained_model_key is not None: + classification_model.load_model( + FLAVA_FOR_PRETRAINED_MAPPING[pretrained_model_key], strict=False + ) + return classification_model def to_2tuple(x): diff --git a/torchmultimodal/utils/common.py b/torchmultimodal/utils/common.py index a52910dd..0caef042 100644 --- a/torchmultimodal/utils/common.py +++ b/torchmultimodal/utils/common.py @@ -146,6 +146,7 @@ def load_model( pretrained_url: Optional[str], load_state_dict: bool = True, state_dict_key: Optional[str] = None, + strict: bool = True, ): assert isinstance( self, torch.nn.Module @@ -160,7 +161,7 @@ def load_model( state_dict = state_dict[state_dict_key] if load_state_dict: - self.load_state_dict(state_dict) + self.load_state_dict(state_dict, strict=strict) return state_dict