Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions scripts/clip.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
cd /mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src
torchrun --nproc_per_node 2 --master_port 3467 -m training.main \
--model "ViT-B-16" \
--train-data "/mmfs1/data/yfcc-tmp/cc_3m/train_shards/shard_{000000..003318}.tar" \
--imagenet-val "/mmfs1/data/yfcc-tmp/imagenet/val/" \
--dataset-type webdataset \
--precision amp \
--gather-with-grad \
--local-loss \
--batch-size 512 \
--accum-freq 1 \
--workers 4 \
--epochs 40 \
--warmup 4000 \
--zeroshot-frequency 2 \
--seed 0 \
--report-to 'wandb' \
--wandb-project-name "mrl_clip_training" \
--logs "/mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src/logs/mrl_clip" \
--name "clip_b512_accum_1_ep40_bugfixed"
25 changes: 25 additions & 0 deletions scripts/finetuning_original.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
cd /mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src
torchrun --nproc_per_node 2 --master_port 4556 -m training.main \
--model "ViT-B-16" \
--pretrained "laion400m_e32" \
--train-data "/mmfs1/data/yfcc-tmp/cc_3m/train_shards/shard_{000000..003318}.tar" \
--imagenet-val "/mmfs1/data/yfcc-tmp/imagenet/val/" \
--dataset-type webdataset \
--precision amp \
--gather-with-grad \
--local-loss \
--force_mrl_loss \
--mrl_loss_weights "1,1,1,1,1" \
--mrl_dim_to_consider "768,384,192,96,48" \
--accum-freq 1 \
--batch-size 512 \
--lr 1e-07 \
--workers 4 \
--epochs 10 \
--warmup 500 \
--zeroshot-frequency 1 \
--seed 0 \
--report-to 'wandb' \
--wandb-project-name "mrl_clip_training" \
--logs "/mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src/logs/mrl_clip" \
--name "ViT-B-16_liaon400m_e32_finetune_mrl_ep10_warmup_500_lr1e-07"
25 changes: 25 additions & 0 deletions scripts/finetuning_weighted.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
cd /mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src
torchrun --nproc_per_node 2 --master_port 4534 -m training.main \
--model "ViT-B-32" \
--pretrained "laion400m_e32" \
--train-data "/mmfs1/data/yfcc-tmp/cc_3m/train_shards/shard_{000000..003318}.tar" \
--imagenet-val "/mmfs1/data/yfcc-tmp/imagenet/val/" \
--dataset-type webdataset \
--precision amp \
--gather-with-grad \
--local-loss \
--force_mrl_loss \
--mrl_loss_weights "0.3,0.25,0.2,0.15,0.1" \
--mrl_dim_to_consider "768,384,192,96,48" \
--accum-freq 1 \
--batch-size 512 \
--lr 1e-07 \
--workers 4 \
--epochs 10 \
--warmup 500 \
--zeroshot-frequency 1 \
--seed 0 \
--report-to 'wandb' \
--wandb-project-name "mrl_clip_training" \
--logs "/mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src/logs/mrl_clip" \
--name "ViT-B-16_liaon400m_e32_finetune_mrl_ep10_warmup_500_wl030250201501_lr1e-07"
16 changes: 16 additions & 0 deletions scripts/launch_this.slurm
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/bin/bash

#SBATCH --account=krishna
#SBATCH --partition=gpu-a100
#SBATCH --job-name=MClip1
#SBATCH --error=./jobs/mrl_clip/job.err.%j
#SBATCH --output=./jobs/mrl_clip/job.run.%j
#SBATCH --time=10-00:00:00
#SBATCH --cpus-per-task=6
#SBATCH --gpus=2
#SBATCH --mem-per-gpu=80G
#SBATCH --nodes=1
#SBATCH [email protected]
#SBATCH --mail-type=ALL

bash mrl_clip.sh
26 changes: 26 additions & 0 deletions scripts/mrl_clip.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
cd ~/open_clip/src
#torchrun --nproc_per_node 2 --master_port 3233 -m training.main \
#lightning run model training/main.py \
python3 -m training.main \
--model "ViT-B-32" \
--train-data "/home/krmayank/data/medium/data/{00000000..00001919}.tar" \
--imagenet-val="/home/krmayank/data/imagenet/imagenet_val/"
--dataset-type webdataset \
--precision amp \
--gather-with-grad \
--local-loss \
--force_mrl_loss \
--mrl_loss_weights "1,1,1,1,1" \
--mrl_dim_to_consider "768,384,192,96,48" \
--batch-size 128 \
--accum-freq 1 \
--workers 4 \
--epochs 3 \
--warmup 4000 \
--zeroshot-frequency 1 \
--seed 1234 \
--num_train_samples 10000
#--report-to 'wandb' \
#--wandb-project-name "mrl_clip_training" \
#--logs "/mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src/logs/mrl_clip" \
#--name "mrl_clip_b512_accum_1_ep40_diffLogitScale_D082723_wl010150202503"
25 changes: 25 additions & 0 deletions scripts/resume.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
cd /mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src
torchrun --nproc_per_node 2 -m --master_port 3436 training.main \
--model "ViT-B-16" \
--train-data "/mmfs1/data/yfcc-tmp/cc_3m/train_shards/shard_{000000..003318}.tar" \
--imagenet-val "/mmfs1/data/yfcc-tmp/imagenet/val/" \
--dataset-type webdataset \
--precision amp \
--gather-with-grad \
--local-loss \
--accum-freq 1 \
--workers 4 \
--epochs 60 \
--warmup 1000 \
--zeroshot-frequency 1 \
--seed 0 \
--resume "/mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src/logs/mrl_clip/clip_b512_accum_1_ep40_bugfixed/checkpoints/epoch_40.pt" \
--report-to 'wandb' \
--wandb-project-name "mrl_clip_training" \
--logs "/mmfs1/gscratch/krishna/mayank/clip_clone/open_clip/src/logs/mrl_clip" \
--name "clip_b512_accum_1_ep40_resumefrom40"


# --force_mrl_loss \
# --mrl_loss_weights "1,1,1,1,1" \
# --mrl_dim_to_consider "768,384,192,96,48" \
35 changes: 35 additions & 0 deletions scripts/training_tpu.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/bin/bash

export PJRT_DEVICE=TPU
#export XLA_IR_DEBUG=1
#export XLA_METRICS_FILE=1
#sudo kill -9 $(sudo lsof -w /dev/accel0 | awk 'NR>1 {print $2}' | uniq)

python -c "import os; os.environ.pop('LD_PRELOAD', None)"

cd ~/open_clip/src/
python3 -m training.main \
--model "ViT-B-32" \
--train-data "/home/krmayank/data/medium/data/{00000000..00001919}.tar" \
--imagenet-val="/home/krmayank/data/imagenet/imagenet_val/" \
--dataset-type webdataset \
--precision amp \
--gather-with-grad \
--local-loss \
--force_mrl_loss \
--mrl_loss_weights "1,1,1,1,1" \
--mrl_dim_to_consider "768,384,192,96,48" \
--batch-size 128 \
--accum-freq 1 \
--workers 4 \
--epochs 3 \
--warmup 4 \
--zeroshot-frequency 1 \
--seed 1234 \
--gather-with-grad \
--train-num-samples 10000 \
--val-num-samples 10000
# --report-to 'wandb' \
# --wandb-project-name "mrl_clip_training"
# --val-data "/home/krmayank/data/medium/data/{00000000..00000000}.tar" \
# --imagenet-val="/home/krmayank/data/imagenet/imagenet_val/" \
28 changes: 25 additions & 3 deletions src/open_clip/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype
from .coca_model import CoCa
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .loss import ClipLoss, DistillClipLoss, CoCaLoss, MRLClipLoss
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
from .transform import image_transform, AugmentationCfg
Expand Down Expand Up @@ -119,6 +119,8 @@ def create_model(
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
require_pretrained: bool = False,
use_mrl: bool = False,
mrl_dim: int = 0
):
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
if has_hf_hub_prefix:
Expand Down Expand Up @@ -191,7 +193,7 @@ def create_model(
else:
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
model = CLIP(**model_cfg, cast_dtype=cast_dtype, use_mrl=use_mrl, mrl_dim=mrl_dim) # intialise with mrl based config if use_mrl

pretrained_loaded = False
if pretrained:
Expand All @@ -204,7 +206,10 @@ def create_model(

if checkpoint_path:
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
load_checkpoint(model, checkpoint_path)
if use_mrl:
load_checkpoint(model, checkpoint_path, strict=False)
else:
load_checkpoint(model, checkpoint_path)
else:
error_str = (
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
Expand Down Expand Up @@ -261,6 +266,17 @@ def create_loss(args):
world_size=args.world_size,
use_horovod=args.horovod,
)
elif args.force_mrl_loss:
return MRLClipLoss(
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod,
mrl_loss_weights=args.mrl_loss_weights,
dim_to_consider=args.mrl_dim_to_consider
)
return ClipLoss(
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
Expand Down Expand Up @@ -288,6 +304,8 @@ def create_model_and_transforms(
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
use_mrl: Optional[bool] = False,
mrl_dim: int = 0
):
model = create_model(
model_name,
Expand All @@ -303,6 +321,8 @@ def create_model_and_transforms(
pretrained_hf=pretrained_hf,
cache_dir=cache_dir,
output_dict=output_dict,
use_mrl=use_mrl,
mrl_dim = mrl_dim
)

image_mean = image_mean or getattr(model.visual, 'image_mean', None)
Expand Down Expand Up @@ -337,6 +357,7 @@ def create_model_from_pretrained(
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
use_mrl = False
):
model = create_model(
model_name,
Expand All @@ -349,6 +370,7 @@ def create_model_from_pretrained(
force_image_size=force_image_size,
cache_dir=cache_dir,
require_pretrained=True,
use_mrl=use_mrl
)

if not return_transform:
Expand Down
54 changes: 54 additions & 0 deletions src/open_clip/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
except ImportError:
hvd = None

try:
import torch_xla.core.xla_model as xm
except ImportError:
xm = None


def gather_features(
image_features,
Expand Down Expand Up @@ -210,3 +215,52 @@ def forward(
return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}

return contrastive_loss, distill_loss

class MRLClipLoss(ClipLoss):
def __init__(self,
local_loss=False,
gather_with_grad=False,
cache_labels=False,
rank=0,
world_size=1,
use_horovod=False,
mrl_loss_weights = None,
dim_to_consider=None):
super().__init__(local_loss,
gather_with_grad,
cache_labels,
rank,
world_size,
use_horovod)
self.mrl_loss_weights = mrl_loss_weights
self.dim_to_consider = dim_to_consider

def normalize(self, features, dim=-1):
if xm.xla_device():
norm = xm.all_reduce("sum", features ** 2)
norm = torch.sqrt(norm)
features = features / norm
return features
return F.normalize(features, dim=dim)

def forward(self, image_features, text_features, logit_scale, output_dict=False):
# print("Inside forward of MRL CLIP loss",len(self.mrl_loss_weights), self.mrl_loss_weights )

assert len(self.mrl_loss_weights) == len(self.dim_to_consider), "number of elements in loss weights and dim_to_consider should be same"

# dim_to_consider = [768, 384, 192, 96, 48]
# total_loss = 0
loss_list = []
for idx, dim in enumerate(self.dim_to_consider):
# img = self.normalize(image_features[:,:dim], dim=-1) # slice and normalize
# txt = self.normalize(text_features[:,:dim], dim=-1) # slice and normalize
img = F.normalize(image_features[:,:dim], dim=-1) # slice and normalize
txt = F.normalize(text_features[:,:dim], dim=-1) # slice and normalize
loss = super().forward(image_features=img, text_features=txt, logit_scale=logit_scale[idx])
# total_loss += self.mrl_loss_weights[idx] * loss
loss_list.append(self.mrl_loss_weights[idx] * loss)

if output_dict:
return {f"mrl_clip_loss_{key}": value for key, value in zip(self.dim_to_consider, loss_list)}

return loss_list
Loading