diff --git a/scripts/clip.sh b/scripts/clip.sh new file mode 100644 index 0000000..b878d08 --- /dev/null +++ b/scripts/clip.sh @@ -0,0 +1,20 @@ +cd /mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src +torchrun --nproc_per_node 2 --master_port 3467 -m training.main \ + --model "ViT-B-16" \ + --train-data "/mmfs1/data/yfcc-tmp/cc_3m/train_shards/shard_{000000..003318}.tar" \ + --imagenet-val "/mmfs1/data/yfcc-tmp/imagenet/val/" \ + --dataset-type webdataset \ + --precision amp \ + --gather-with-grad \ + --local-loss \ + --batch-size 512 \ + --accum-freq 1 \ + --workers 4 \ + --epochs 40 \ + --warmup 4000 \ + --zeroshot-frequency 2 \ + --seed 0 \ + --report-to 'wandb' \ + --wandb-project-name "mrl_clip_training" \ + --logs "/mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src/logs/mrl_clip" \ + --name "clip_b512_accum_1_ep40_bugfixed" \ No newline at end of file diff --git a/scripts/finetuning_original.sh b/scripts/finetuning_original.sh new file mode 100644 index 0000000..9382e4b --- /dev/null +++ b/scripts/finetuning_original.sh @@ -0,0 +1,25 @@ +cd /mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src +torchrun --nproc_per_node 2 --master_port 4556 -m training.main \ + --model "ViT-B-16" \ + --pretrained "laion400m_e32" \ + --train-data "/mmfs1/data/yfcc-tmp/cc_3m/train_shards/shard_{000000..003318}.tar" \ + --imagenet-val "/mmfs1/data/yfcc-tmp/imagenet/val/" \ + --dataset-type webdataset \ + --precision amp \ + --gather-with-grad \ + --local-loss \ + --force_mrl_loss \ + --mrl_loss_weights "1,1,1,1,1" \ + --mrl_dim_to_consider "768,384,192,96,48" \ + --accum-freq 1 \ + --batch-size 512 \ + --lr 1e-07 \ + --workers 4 \ + --epochs 10 \ + --warmup 500 \ + --zeroshot-frequency 1 \ + --seed 0 \ + --report-to 'wandb' \ + --wandb-project-name "mrl_clip_training" \ + --logs "/mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src/logs/mrl_clip" \ + --name "ViT-B-16_liaon400m_e32_finetune_mrl_ep10_warmup_500_lr1e-07" \ No newline at end of file diff --git 
a/scripts/finetuning_weighted.sh b/scripts/finetuning_weighted.sh new file mode 100644 index 0000000..6b5ada3 --- /dev/null +++ b/scripts/finetuning_weighted.sh @@ -0,0 +1,25 @@ +cd /mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src +torchrun --nproc_per_node 2 --master_port 4534 -m training.main \ + --model "ViT-B-32" \ + --pretrained "laion400m_e32" \ + --train-data "/mmfs1/data/yfcc-tmp/cc_3m/train_shards/shard_{000000..003318}.tar" \ + --imagenet-val "/mmfs1/data/yfcc-tmp/imagenet/val/" \ + --dataset-type webdataset \ + --precision amp \ + --gather-with-grad \ + --local-loss \ + --force_mrl_loss \ + --mrl_loss_weights "0.3,0.25,0.2,0.15,0.1" \ + --mrl_dim_to_consider "768,384,192,96,48" \ + --accum-freq 1 \ + --batch-size 512 \ + --lr 1e-07 \ + --workers 4 \ + --epochs 10 \ + --warmup 500 \ + --zeroshot-frequency 1 \ + --seed 0 \ + --report-to 'wandb' \ + --wandb-project-name "mrl_clip_training" \ + --logs "/mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src/logs/mrl_clip" \ + --name "ViT-B-16_liaon400m_e32_finetune_mrl_ep10_warmup_500_wl030250201501_lr1e-07" \ No newline at end of file diff --git a/scripts/launch_this.slurm b/scripts/launch_this.slurm new file mode 100644 index 0000000..41dce10 --- /dev/null +++ b/scripts/launch_this.slurm @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --account=krishna +#SBATCH --partition=gpu-a100 +#SBATCH --job-name=MClip1 +#SBATCH --error=./jobs/mrl_clip/job.err.%j +#SBATCH --output=./jobs/mrl_clip/job.run.%j +#SBATCH --time=10-00:00:00 +#SBATCH --cpus-per-task=6 +#SBATCH --gpus=2 +#SBATCH --mem-per-gpu=80G +#SBATCH --nodes=1 +#SBATCH --mail-user=krmayank@uw.edu +#SBATCH --mail-type=ALL + +bash mrl_clip.sh \ No newline at end of file diff --git a/scripts/mrl_clip.sh b/scripts/mrl_clip.sh new file mode 100644 index 0000000..45b9ac4 --- /dev/null +++ b/scripts/mrl_clip.sh @@ -0,0 +1,26 @@ +cd ~/open_clip/src +#torchrun --nproc_per_node 2 --master_port 3233 -m training.main \ +#lightning run model training/main.py \ 
+python3 -m training.main \ + --model "ViT-B-32" \ + --train-data "/home/krmayank/data/medium/data/{00000000..00001919}.tar" \ + --imagenet-val="/home/krmayank/data/imagenet/imagenet_val/" \ + --dataset-type webdataset \ + --precision amp \ + --gather-with-grad \ + --local-loss \ + --force_mrl_loss \ + --mrl_loss_weights "1,1,1,1,1" \ + --mrl_dim_to_consider "768,384,192,96,48" \ + --batch-size 128 \ + --accum-freq 1 \ + --workers 4 \ + --epochs 3 \ + --warmup 4000 \ + --zeroshot-frequency 1 \ + --seed 1234 \ + --train-num-samples 10000 + #--report-to 'wandb' \ + #--wandb-project-name "mrl_clip_training" \ + #--logs "/mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src/logs/mrl_clip" \ + #--name "mrl_clip_b512_accum_1_ep40_diffLogitScale_D082723_wl010150202503" diff --git a/scripts/resume.sh b/scripts/resume.sh new file mode 100644 index 0000000..194c212 --- /dev/null +++ b/scripts/resume.sh @@ -0,0 +1,25 @@ +cd /mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src +torchrun --nproc_per_node 2 --master_port 3436 -m training.main \ + --model "ViT-B-16" \ + --train-data "/mmfs1/data/yfcc-tmp/cc_3m/train_shards/shard_{000000..003318}.tar" \ + --imagenet-val "/mmfs1/data/yfcc-tmp/imagenet/val/" \ + --dataset-type webdataset \ + --precision amp \ + --gather-with-grad \ + --local-loss \ + --accum-freq 1 \ + --workers 4 \ + --epochs 60 \ + --warmup 1000 \ + --zeroshot-frequency 1 \ + --seed 0 \ + --resume "/mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src/logs/mrl_clip/clip_b512_accum_1_ep40_bugfixed/checkpoints/epoch_40.pt" \ + --report-to 'wandb' \ + --wandb-project-name "mrl_clip_training" \ + --logs "/mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src/logs/mrl_clip" \ + --name "clip_b512_accum_1_ep40_resumefrom40" + + + # --force_mrl_loss \ + # --mrl_loss_weights "1,1,1,1,1" \ + # --mrl_dim_to_consider "768,384,192,96,48" \ \ No newline at end of file diff --git a/scripts/training_tpu.sh b/scripts/training_tpu.sh new file mode 100644 index 
0000000..415354a --- /dev/null +++ b/scripts/training_tpu.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +export PJRT_DEVICE=TPU +#export XLA_IR_DEBUG=1 +#export XLA_METRICS_FILE=1 +#sudo kill -9 $(sudo lsof -w /dev/accel0 | awk 'NR>1 {print $2}' | uniq) + +unset LD_PRELOAD + +cd ~/open_clip/src/ +python3 -m training.main \ + --model "ViT-B-32" \ + --train-data "/home/krmayank/data/medium/data/{00000000..00001919}.tar" \ + --imagenet-val="/home/krmayank/data/imagenet/imagenet_val/" \ + --dataset-type webdataset \ + --precision amp \ + --gather-with-grad \ + --local-loss \ + --force_mrl_loss \ + --mrl_loss_weights "1,1,1,1,1" \ + --mrl_dim_to_consider "768,384,192,96,48" \ + --batch-size 128 \ + --accum-freq 1 \ + --workers 4 \ + --epochs 3 \ + --warmup 4 \ + --zeroshot-frequency 1 \ + --seed 1234 \ + --use_tpu \ + --train-num-samples 10000 \ + --val-num-samples 10000 +# --report-to 'wandb' \ +# --wandb-project-name "mrl_clip_training" +# --val-data "/home/krmayank/data/medium/data/{00000000..00000000}.tar" \ +# --imagenet-val="/home/krmayank/data/imagenet/imagenet_val/" \ diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 14011f9..b23d2bf 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -13,7 +13,7 @@ from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ resize_pos_embed, get_cast_dtype from .coca_model import CoCa -from .loss import ClipLoss, DistillClipLoss, CoCaLoss +from .loss import ClipLoss, DistillClipLoss, CoCaLoss, MRLClipLoss from .openai import load_openai_model from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf from .transform import image_transform, AugmentationCfg @@ -119,6 +119,8 @@ def create_model( cache_dir: Optional[str] = None, output_dict: Optional[bool] = None, require_pretrained: bool = False, + use_mrl: bool = False, + 
mrl_dim: int = 0 ): has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) if has_hf_hub_prefix: @@ -191,7 +193,7 @@ def create_model( else: model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) else: - model = CLIP(**model_cfg, cast_dtype=cast_dtype) + model = CLIP(**model_cfg, cast_dtype=cast_dtype, use_mrl=use_mrl, mrl_dim=mrl_dim) # intialise with mrl based config if use_mrl pretrained_loaded = False if pretrained: @@ -204,7 +206,10 @@ def create_model( if checkpoint_path: logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') - load_checkpoint(model, checkpoint_path) + if use_mrl: + load_checkpoint(model, checkpoint_path, strict=False) + else: + load_checkpoint(model, checkpoint_path) else: error_str = ( f'Pretrained weights ({pretrained}) not found for model {model_name}.' @@ -261,6 +266,17 @@ def create_loss(args): world_size=args.world_size, use_horovod=args.horovod, ) + elif args.force_mrl_loss: + return MRLClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + mrl_loss_weights=args.mrl_loss_weights, + dim_to_consider=args.mrl_dim_to_consider + ) return ClipLoss( local_loss=args.local_loss, gather_with_grad=args.gather_with_grad, @@ -288,6 +304,8 @@ def create_model_and_transforms( aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, cache_dir: Optional[str] = None, output_dict: Optional[bool] = None, + use_mrl: Optional[bool] = False, + mrl_dim: int = 0 ): model = create_model( model_name, @@ -303,6 +321,8 @@ def create_model_and_transforms( pretrained_hf=pretrained_hf, cache_dir=cache_dir, output_dict=output_dict, + use_mrl=use_mrl, + mrl_dim = mrl_dim ) image_mean = image_mean or getattr(model.visual, 'image_mean', None) @@ -337,6 +357,7 @@ def create_model_from_pretrained( image_mean: Optional[Tuple[float, ...]] = None, image_std: Optional[Tuple[float, ...]] = None, cache_dir: Optional[str] 
= None, + use_mrl = False ): model = create_model( model_name, @@ -349,6 +370,7 @@ def create_model_from_pretrained( force_image_size=force_image_size, cache_dir=cache_dir, require_pretrained=True, + use_mrl=use_mrl ) if not return_transform: diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 4fbf61d..4580c12 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -15,6 +15,11 @@ except ImportError: hvd = None +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + def gather_features( image_features, @@ -210,3 +215,48 @@ def forward( return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss} return contrastive_loss, distill_loss + +class MRLClipLoss(ClipLoss): + def __init__(self, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + mrl_loss_weights = None, + dim_to_consider=None): + super().__init__(local_loss, + gather_with_grad, + cache_labels, + rank, + world_size, + use_horovod) + self.mrl_loss_weights = mrl_loss_weights + self.dim_to_consider = dim_to_consider + + def normalize(self, features, dim=-1): + # Plain L2 normalisation along `dim`, the same call forward() uses; needs no torch_xla (xm may be None). + return F.normalize(features, dim=dim) + + def forward(self, image_features, text_features, logit_scale, output_dict=False): + # print("Inside forward of MRL CLIP loss",len(self.mrl_loss_weights), self.mrl_loss_weights ) + + assert len(self.mrl_loss_weights) == len(self.dim_to_consider), "number of elements in loss weights and dim_to_consider should be same" + + # dim_to_consider = [768, 384, 192, 96, 48] + # total_loss = 0 + loss_list = [] + for idx, dim in enumerate(self.dim_to_consider): + # img = self.normalize(image_features[:,:dim], dim=-1) # slice and normalize + # txt = self.normalize(text_features[:,:dim], dim=-1) # slice and normalize + img = F.normalize(image_features[:,:dim], 
dim=-1) # slice and normalize + txt = F.normalize(text_features[:,:dim], dim=-1) # slice and normalize + loss = super().forward(image_features=img, text_features=txt, logit_scale=logit_scale[idx]) + # total_loss += self.mrl_loss_weights[idx] * loss + loss_list.append(self.mrl_loss_weights[idx] * loss) + + if output_dict: + return {f"mrl_clip_loss_{key}": value for key, value in zip(self.dim_to_consider, loss_list)} + + return loss_list diff --git a/src/open_clip/model.py b/src/open_clip/model.py index 4f5e775..991fd60 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -19,6 +19,10 @@ from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer from .utils import to_2tuple +try: + import torch_xla.core.xla_model as xm +except ImportError: + pass @dataclass class CLIPVisionCfg: @@ -184,9 +188,13 @@ def __init__( quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, output_dict: bool = False, + use_mrl: bool = False, + mrl_dim: int = 0, ): super().__init__() self.output_dict = output_dict + self.use_mrl = use_mrl + self.mrl_dim = mrl_dim self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) @@ -197,8 +205,12 @@ def __init__( self.ln_final = text.ln_final self.text_projection = text.text_projection self.register_buffer('attn_mask', text.attn_mask, persistent=False) - - self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + if use_mrl: + # print("using mrl inside clip initialization") + self.logit_scale = nn.ParameterList([nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) for _ in range(self.mrl_dim)]) + else: + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 @@ -228,15 +240,29 @@ def encode_text(self, text, normalize: bool = False): 
return F.normalize(x, dim=-1) if normalize else x def forward(self, image, text): - image_features = self.encode_image(image, normalize=True) - text_features = self.encode_text(text, normalize=True) - if self.output_dict: - return { - "image_features": image_features, - "text_features": text_features, - "logit_scale": self.logit_scale.exp() - } - return image_features, text_features, self.logit_scale.exp() + if self.use_mrl: + # print("inside CLIP main class, using MRL") + image_features = self.encode_image(image, normalize=False) + text_features = self.encode_text(text, normalize=False) + + if self.output_dict: + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": [scale.exp() for scale in self.logit_scale] + } + return image_features, text_features, [scale.exp() for scale in self.logit_scale] + else: + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + if self.output_dict: + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + return image_features, text_features, self.logit_scale.exp() + class CustomTextCLIP(nn.Module): diff --git a/src/training/distributed.py b/src/training/distributed.py index 268a6c7..5f44dd1 100644 --- a/src/training/distributed.py +++ b/src/training/distributed.py @@ -8,6 +8,10 @@ except ImportError: hvd = None +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None def is_global_master(args): return args.rank == 0 @@ -18,6 +22,8 @@ def is_local_master(args): def is_master(args, local=False): + if xm is not None: + return xm.is_master_ordinal() return is_local_master(args) if local else is_global_master(args) @@ -39,7 +45,6 @@ def is_using_distributed(): return int(os.environ['SLURM_NTASKS']) > 1 return False - def world_info_from_env(): local_rank = 0 for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 
'OMPI_COMM_WORLD_LOCAL_RANK'): @@ -107,6 +112,8 @@ def init_distributed_device(args): else: device = 'cuda:0' torch.cuda.set_device(device) + elif xm is not None: + device = xm.xla_device() else: device = 'cpu' args.device = device @@ -118,6 +125,8 @@ def broadcast_object(args, obj, src=0): # broadcast a pickle-able python object from rank-0 to all ranks if args.horovod: return hvd.broadcast_object(obj, root_rank=src) + elif xm is not None: + return xm.mesh_reduce('broadcast_object', obj, lambda x: x, src) else: if args.rank == src: objects = [obj] @@ -131,6 +140,8 @@ def all_gather_object(args, obj, dst=0): # gather a pickle-able python object across all ranks if args.horovod: return hvd.allgather_object(obj) + elif xm is not None: + return xm.mesh_reduce('all_gather_object', obj, lambda x: x, dst) else: objects = [None for _ in range(args.world_size)] dist.all_gather_object(objects, obj) diff --git a/src/training/main.py b/src/training/main.py index f70c9f9..a49e2de 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -1,3 +1,5 @@ +# modify this file to support distributed training on tpus and gpus +# this file is a modified version of the original main.py file from the open_clip repo import glob import logging import os @@ -12,6 +14,15 @@ from torch import optim from torch.cuda.amp import GradScaler +try: + import torch_xla.core.xla_model as xm + import torch_xla.distributed.parallel_loader as pl + import torch_xla.distributed.xla_multiprocessing as xmp +except ImportError: + xm = None + pl = None + xmp = None + try: import wandb except ImportError: @@ -67,9 +78,9 @@ def get_latest_checkpoint(path: str, remote : bool): return None -def main(args): +def main(index, args): args = parse_args(args) - + # check if args.use_tpu is set, if so, use torch_xla if torch.cuda.is_available(): # This enables tf32 on Ampere GPUs which is only 8% slower than # float16 and almost as accurate as float32 @@ -77,6 +88,11 @@ def main(args): 
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.benchmark = True torch.backends.cudnn.deterministic = False + elif args.use_tpu: + assert xm is not None, "Please install torch_xla to use TPUs." + xm.set_rng_state(args.seed, device=xm.xla_device()) + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False # fully initialize distributed device environment device = init_distributed_device(args) @@ -230,6 +246,8 @@ def main(args): image_std=args.image_std, aug_cfg=args.aug_cfg, output_dict=True, + use_mrl=args.force_mrl_loss, + mrl_dim = len(args.mrl_dim_to_consider) ) if args.distill: # FIXME: currenlty assumes the model your distilling from has the same tokenizer & transforms. @@ -310,7 +328,7 @@ def main(args): hvd.broadcast_parameters(model.state_dict(), root_rank=0) hvd.broadcast_optimizer_state(optimizer, root_rank=0) - scaler = GradScaler() if args.precision == "amp" else None + scaler = GradScaler() if args.precision == "amp" and not args.use_tpu else None # optionally resume from a checkpoint start_epoch = 0 @@ -334,7 +352,14 @@ def main(args): logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") # initialize datasets - data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model)) + if args.use_tpu: + if not xm.is_master_ordinal(): + xm.rendezvous('download_only_once') + data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model)) + if xm.is_master_ordinal(): + xm.rendezvous('download_only_once') + else: + data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model)) assert len(data), 'At least one train or eval dataset must be specified.' 
# create scheduler if train @@ -390,7 +415,7 @@ def main(args): return loss = create_loss(args) - + for epoch in range(start_epoch, args.epochs): if is_master(args): logging.info(f'Start epoch {epoch}') @@ -466,5 +491,13 @@ def copy_codebase(args): return 1 +def run_tpu(): + # xmp.spawn calls main(index, args) in each spawned process; nprocs is left at its default + xmp.spawn(main, args=(sys.argv[1:],)) + + if __name__ == "__main__": - main(sys.argv[1:]) + if xmp is not None: + run_tpu() + else: + main(0, sys.argv[1:]) diff --git a/src/training/params.py b/src/training/params.py index 36c693b..f480b20 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -22,6 +22,11 @@ def __call__(self, parser, namespace, values, option_string=None): kw[key] = str(value) # fallback to string (avoid need to escape on command line) setattr(namespace, self.dest, kw) +def parse_mrl_loss_weights(weights_str): + return [float(x) for x in weights_str.split(",")] + +def parse_dim_to_consider(dim_to_consider): + return [int(x) for x in dim_to_consider.split(",")] def parse_args(args): parser = argparse.ArgumentParser() @@ -424,6 +429,30 @@ def parse_args(args): default=None, help='Which pre-trained weights to distill from, if any.' ) + parser.add_argument( + "--force_mrl_loss", + default=False, + action="store_true", + help="whether to use MRL based loss" + ) + parser.add_argument( + "--mrl_loss_weights", + default=[1,1,1,1,1], + type=parse_mrl_loss_weights, + help="comma-separated loss weights, one per entry of --mrl_dim_to_consider and in the same order, e.g. 1,1,1,1,1" + ) + parser.add_argument( + "--mrl_dim_to_consider", + default=[768, 384, 192, 96, 48], + type=parse_dim_to_consider, + help="comma-separated embedding prefix sizes the MRL loss is computed on, e.g. 768,384,192,96,48" + ) + parser.add_argument( + "--use_tpu", + default=False, + action="store_true", + help="Whether to use TPUs for training" + ) args = parser.parse_args(args) # If some params are not passed, we use the default values based on model name. 
diff --git a/src/training/train.py b/src/training/train.py index e0a140f..b2d5a50 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -8,12 +8,20 @@ import torch import torch.nn.functional as F from torch.nn.parallel.distributed import DistributedDataParallel +import contextlib try: import wandb except ImportError: wandb = None +try: + import torch_xla.core.xla_model as xm + import torch_xla.distributed.parallel_loader as pl +except ImportError: + xm = None + pl = None + from open_clip import get_cast_dtype, CLIP, CustomTextCLIP from .distributed import is_master from .zero_shot import zero_shot_eval @@ -61,7 +69,7 @@ def backward(total_loss, scaler): def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=None): device = torch.device(args.device) - autocast = get_autocast(args.precision) + autocast = get_autocast(args.precision) if not args.use_tpu else contextlib.nullcontext cast_dtype = get_cast_dtype(args.precision) @@ -72,7 +80,10 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch dataloader = data['train'].dataloader num_batches_per_epoch = dataloader.num_batches // args.accum_freq + samples_per_epoch = dataloader.num_samples sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) + if args.use_tpu: + dataloader = pl.ParallelLoader(dataloader, [device]).per_device_loader(device) if args.accum_freq > 1: accum_images, accum_texts, accum_features = [], [], {} @@ -166,7 +177,10 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist else: if args.grad_clip_norm is not None: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) - optimizer.step() + if args.use_tpu: + xm.optimizer_step(optimizer) + else: + optimizer.step() # reset gradient accum, if enabled if args.accum_freq > 1: @@ -174,7 +188,11 @@ def 
train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist # Note: we clamp to 4.6052 = ln(100), as in the original paper. with torch.no_grad(): - unwrap_model(model).logit_scale.clamp_(0, math.log(100)) + if args.force_mrl_loss: + for idx in range(len(args.mrl_dim_to_consider)): + unwrap_model(model).logit_scale[idx].clamp_(0, math.log(100)) + else: + unwrap_model(model).logit_scale.clamp_(0, math.log(100)) batch_time_m.update(time.time() - end) end = time.time() @@ -182,16 +200,24 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch): batch_size = len(images) num_samples = batch_count * batch_size * args.accum_freq * args.world_size - samples_per_epoch = dataloader.num_samples + # samples_per_epoch = dataloader.num_samples percent_complete = 100.0 * batch_count / num_batches_per_epoch # NOTE loss is coarsely sampled, just master node and per log update for key, val in losses.items(): if key not in losses_m: losses_m[key] = AverageMeter() - losses_m[key].update(val.item(), batch_size) - - logit_scale_scalar = logit_scale.item() + if args.use_tpu: + losses_m[key].update(val, batch_size) + else: + losses_m[key].update(val.item(), batch_size) + + if args.force_mrl_loss: + logit_scale_scalar = [logit_scale[i].item() for i in range(len(args.mrl_dim_to_consider))] + logit_scale_string = " ,".join([f"{item:.3f}" for item in logit_scale_scalar]) + else: + logit_scale_scalar = logit_scale.item() + logit_scale_string = f"{logit_scale_scalar:.3f}" loss_log = " ".join( [ f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})" @@ -205,18 +231,29 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist f"Data (t): {data_time_m.avg:.3f} " f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu " f"LR: {optimizer.param_groups[0]['lr']:5f} " - 
f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log + f"Logit Scale: {logit_scale_string} " + loss_log ) # Save train loss / etc. Using non avg meter values as loggers have their own smoothing - log_data = { - "data_time": data_time_m.val, - "batch_time": batch_time_m.val, - "samples_per_second": samples_per_second, - "samples_per_second_per_gpu": samples_per_second_per_gpu, - "scale": logit_scale_scalar, - "lr": optimizer.param_groups[0]["lr"] - } + if args.force_mrl_loss: + log_data = { + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "samples_per_second": samples_per_second, + "samples_per_second_per_gpu": samples_per_second_per_gpu, + "lr": optimizer.param_groups[0]["lr"] + } + log_data.update({f"scale_{dim}": logit_scale_scalar[idx] for idx, dim in enumerate(args.mrl_dim_to_consider)}) + else: + log_data = { + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "samples_per_second": samples_per_second, + "samples_per_second_per_gpu": samples_per_second_per_gpu, + "scale": logit_scale_scalar, + "lr": optimizer.param_groups[0]["lr"] + } + log_data.update({name:val.val for name,val in losses_m.items()}) for name, val in log_data.items(): @@ -243,7 +280,7 @@ def evaluate(model, data, epoch, args, tb_writer=None): zero_shot_metrics = zero_shot_eval(model, data, epoch, args) metrics.update(zero_shot_metrics) - autocast = get_autocast(args.precision) + autocast = get_autocast(args.precision) if not args.use_tpu else contextlib.nullcontext cast_dtype = get_cast_dtype(args.precision) if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)): @@ -251,6 +288,9 @@ def evaluate(model, data, epoch, args, tb_writer=None): num_samples = 0 samples_per_val = dataloader.num_samples + if args.use_tpu: + dataloader = pl.ParallelLoader(dataloader, [device]).per_device_loader(device) + # FIXME this does not scale past small eval datasets # all_image_features @ all_text_features will blow up memory and 
compute very quickly cumulative_loss = 0.0 @@ -271,7 +311,10 @@ def evaluate(model, data, epoch, args, tb_writer=None): # however, system RAM is easily exceeded and compute time becomes problematic all_image_features.append(image_features.cpu()) all_text_features.append(text_features.cpu()) - logit_scale = logit_scale.mean() + if not args.force_mrl_loss: + logit_scale = logit_scale.mean() + else: + logit_scale = sum(logit_scale)/len(logit_scale) logits_per_image = logit_scale * image_features @ text_features.t() logits_per_text = logits_per_image.t() diff --git a/src/training/zero_shot.py b/src/training/zero_shot.py index e5768b4..9f062a1 100644 --- a/src/training/zero_shot.py +++ b/src/training/zero_shot.py @@ -3,14 +3,23 @@ import torch import torch.nn.functional as F from tqdm import tqdm +import contextlib +try: + import torch_xla.core.xla_model as xm + import torch_xla.distributed.parallel_loader as pl +except ImportError: + xm = None + pl = None from open_clip import get_cast_dtype, get_tokenizer from .precision import get_autocast from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template -def zero_shot_classifier(model, classnames, templates, args): +def zero_shot_classifier(model, classnames, templates, dim, args): + if xm is not None: xm.mark_step() tokenizer = get_tokenizer(args.model) + if xm is not None: xm.mark_step() with torch.no_grad(): zeroshot_weights = [] for classname in tqdm(classnames): @@ -20,10 +29,14 @@ def zero_shot_classifier(model, classnames, templates, args): class_embeddings = model.module.encode_text(texts) else: class_embeddings = model.encode_text(texts) + if args.force_mrl_loss: + class_embeddings = class_embeddings[:,:dim] # consider only dim for creating zeroshot classifier class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) class_embedding /= class_embedding.norm() zeroshot_weights.append(class_embedding) + if xm is not None: xm.mark_step() zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device) + if xm is not None: xm.mark_step() return zeroshot_weights @@ -33,9 +46,15 @@ def accuracy(output, 
target, topk=(1,)): return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] -def run(model, classifier, dataloader, args): - autocast = get_autocast(args.precision) - cast_dtype = get_cast_dtype(args.precision) +def run(model, classifier, dataloader, dim, args): + autocast = get_autocast(args.precision) if not args.use_tpu else contextlib.nullcontext + cast_dtype = get_cast_dtype(args.precision) + + if args.use_tpu: + device = xm.xla_device() + para_loader = pl.ParallelLoader(dataloader, [device]) + dataloader = para_loader.per_device_loader(device) + with torch.no_grad(): top1, top5, n = 0., 0., 0. for images, target in tqdm(dataloader, unit_scale=args.batch_size): @@ -50,11 +69,15 @@ def run(model, classifier, dataloader, args): image_features = model.module.encode_image(images) else: image_features = model.encode_image(images) + if args.force_mrl_loss: + image_features = image_features[:,:dim] image_features = F.normalize(image_features, dim=-1) logits = 100. 
* image_features @ classifier # measure accuracy + if xm is not None: xm.mark_step() acc1, acc5 = accuracy(logits, target, topk=(1, 5)) + if xm is not None: xm.mark_step() top1 += acc1 top5 += acc5 n += images.size(0) @@ -74,20 +97,43 @@ def zero_shot_eval(model, data, epoch, args): logging.info('Starting zero-shot imagenet.') - logging.info('Building zero-shot classifier') - classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args) - - logging.info('Using classifier') - results = {} - if 'imagenet-val' in data: - top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args) - results['imagenet-zeroshot-val-top1'] = top1 - results['imagenet-zeroshot-val-top5'] = top5 - if 'imagenet-v2' in data: - top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args) - results['imagenetv2-zeroshot-val-top1'] = top1 - results['imagenetv2-zeroshot-val-top5'] = top5 - + # if MRL + if args.force_mrl_loss: + results = {} + for dim in args.mrl_dim_to_consider: + logging.info(f'Building zero-shot classifier dim-{dim}') + classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, dim, args) + + logging.info('Using classifier') + if 'imagenet-val' in data: + top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, dim, args) + top1_name = f'imagenet-zeroshot-val-d{dim}-top1' + top5_name = f'imagenet-zeroshot-val-d{dim}-top5' + results[top1_name] = top1 + results[top5_name] = top5 + if 'imagenet-v2' in data: + top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, dim, args) + top1_name = f'imagenetv2-zeroshot-val-d{dim}-top1' + top5_name = f'imagenetv2-zeroshot-val-d{dim}-top5' + results[top1_name] = top1 + results[top5_name] = top5 + + # for other losses than MRL + else: + logging.info('Building zero-shot classifier') + classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, 0, args) + + logging.info('Using classifier') + results = {} + if 'imagenet-val' in data: + top1, top5 = 
run(model, classifier, data['imagenet-val'].dataloader, 0, args) + results['imagenet-zeroshot-val-top1'] = top1 + results['imagenet-zeroshot-val-top5'] = top5 + if 'imagenet-v2' in data: + top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, 0, args) + results['imagenetv2-zeroshot-val-top1'] = top1 + results['imagenetv2-zeroshot-val-top5'] = top5 + + # if not MRL logging.info('Finished zero-shot imagenet.') - return results