diff --git a/dinov2/configs/train/vits14_pyramid.yaml b/dinov2/configs/train/vits14_pyramid.yaml new file mode 100644 index 000000000..26526fe2d --- /dev/null +++ b/dinov2/configs/train/vits14_pyramid.yaml @@ -0,0 +1,68 @@ +dino: + head_n_prototypes: 65536 + head_bottleneck_dim: 256 + head_hidden_dim: 2048 + head_nlayers: 3 + loss_weight: 1.0 + koleo_loss_weight: 0.1 +ibot: + loss_weight: 1.0 + mask_sample_probability: 0.5 + mask_ratio_min_max: + - 0.1 + - 0.5 + separate_head: true + head_n_prototypes: 65536 + head_bottleneck_dim: 256 + head_hidden_dim: 2048 + head_nlayers: 3 +train: + batch_size_per_gpu: 64 + dataset_path: HuggingFace:root=tsbpp/fall2025_deeplearning:split=train + centering: sinkhorn_knopp + output_dir: /gpfs/data/shenlab/aj4718/dinov2/logs/vits14_pyramid + OFFICIAL_EPOCH_LENGTH: 7813 + num_workers: 4 + saveckp_freq: 1 +student: + arch: vit_small + patch_size: 16 + drop_path_rate: 0.4 + layerscale: 1.0e-05 + drop_path_uniform: true + pretrained_weights: '' + ffn_layer: mlp + block_chunks: 0 +teacher: + momentum_teacher: 0.994 + final_momentum_teacher: 1.0 + warmup_teacher_temp: 0.04 + teacher_temp: 0.07 + warmup_teacher_temp_epochs: 30 +optim: + epochs: 300 + weight_decay: 0.04 + weight_decay_end: 0.4 + base_lr: 0.002 + lr: 0.0 # will be set by apply_scaling_rules_to_cfg + warmup_epochs: 10 + min_lr: 1.0e-06 + clip_grad: 3.0 + freeze_last_layer_epochs: 1 + scaling_rule: sqrt_wrt_1024 + patch_embed_lr_mult: 0.2 + layerwise_decay: 0.9 + adamw_beta1: 0.9 + adamw_beta2: 0.999 +crops: + global_crops_scale: + - 0.32 + - 1.0 + local_crops_number: 8 + local_crops_scale: + - 0.05 + - 0.32 + global_crops_size: 96 + local_crops_size: 48 +evaluation: + eval_period_iterations: 10000 diff --git a/dinov2/data/datasets/__init__.py b/dinov2/data/datasets/__init__.py index 5550fdc5c..8e6acca90 100644 --- a/dinov2/data/datasets/__init__.py +++ b/dinov2/data/datasets/__init__.py @@ -5,3 +5,4 @@ from .image_net import ImageNet from .image_net_22k import ImageNet22k +from .huggingface import HuggingFaceDataset diff --git a/dinov2/data/datasets/huggingface.py b/dinov2/data/datasets/huggingface.py new file mode 100644 index 000000000..c1e173ca9 --- /dev/null +++ b/dinov2/data/datasets/huggingface.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from typing import Any, Callable, Optional, Tuple +import logging + +from datasets import load_dataset +from PIL import Image +import torch +from torchvision.datasets import VisionDataset + +logger = logging.getLogger("dinov2") + +class HuggingFaceDataset(VisionDataset): + def __init__( + self, + root: str, + split: str = "train", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + super().__init__(root, transform=transform, target_transform=target_transform) + self.split = split + self.dataset_name = root # Reusing root as dataset name for consistency with other datasets + + logger.info(f"Loading HuggingFace dataset: {self.dataset_name}, split: {self.split}") + self.dataset = load_dataset(self.dataset_name, split=self.split) + logger.info(f"Loaded {len(self.dataset)} samples from {self.dataset_name}") + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + item = self.dataset[index] + + # Assumes 'image' column. Modify if needed. 
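+        # Each item is a dict keyed by column name; with an Image() feature the 'image'
+        # column typically decodes to a PIL.Image on access, hence the RGB conversion below.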
+ image = item['image'] + + if image.mode != 'RGB': + image = image.convert('RGB') + + target = 0 # Dummy target for SSL + + if self.transform is not None: + image = self.transform(image) + + if self.target_transform is not None: + target = self.target_transform(target) + + return image, target + + def __len__(self) -> int: + return len(self.dataset) diff --git a/dinov2/data/loaders.py b/dinov2/data/loaders.py index d6a2f0210..687e10e0d 100644 --- a/dinov2/data/loaders.py +++ b/dinov2/data/loaders.py @@ -10,7 +10,7 @@ import torch from torch.utils.data import Sampler -from .datasets import ImageNet, ImageNet22k +from .datasets import ImageNet, ImageNet22k, HuggingFaceDataset from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler @@ -58,6 +58,8 @@ def _parse_dataset_str(dataset_str: str): kwargs["split"] = ImageNet.Split[kwargs["split"]] elif name == "ImageNet22k": class_ = ImageNet22k + elif name == "HuggingFace": + class_ = HuggingFaceDataset else: raise ValueError(f'Unsupported dataset "{name}"') diff --git a/dinov2/train/ssl_meta_arch_pyramid.py b/dinov2/train/ssl_meta_arch_pyramid.py new file mode 100644 index 000000000..6d17abd46 --- /dev/null +++ b/dinov2/train/ssl_meta_arch_pyramid.py @@ -0,0 +1,509 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from functools import partial +import logging + +import torch +from torch import nn +import torch.nn.functional as F + +from dinov2.loss import DINOLoss, iBOTPatchLoss, KoLeoLoss +from dinov2.models import build_model_from_cfg +from dinov2.layers import DINOHead, Mlp +from dinov2.utils.utils import has_batchnorms +from dinov2.utils.param_groups import get_params_groups_with_decay, fuse_params_groups +from dinov2.fsdp import get_fsdp_wrapper, ShardedGradScaler, get_fsdp_modules, reshard_fsdp_model + +from dinov2.models.vision_transformer import BlockChunk + + +try: + from xformers.ops import fmha +except ImportError: + raise AssertionError("xFormers is required for training") + + +logger = logging.getLogger("dinov2") + + +class PyramidSSLMetaArch(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.fp16_scaler = ShardedGradScaler() if cfg.compute_precision.grad_scaler else None + + student_model_dict = dict() + teacher_model_dict = dict() + + student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg) + student_model_dict["backbone"] = student_backbone + teacher_model_dict["backbone"] = teacher_backbone + logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}") + + if cfg.student.pretrained_weights: + chkpt = torch.load(cfg.student.pretrained_weights) + logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.student.pretrained_weights}") + student_backbone.load_state_dict(chkpt["model"], strict=False) + + self.embed_dim = embed_dim + self.dino_out_dim = cfg.dino.head_n_prototypes + + self.do_dino = cfg.dino.loss_weight > 0 + self.do_koleo = cfg.dino.koleo_loss_weight > 0 + self.do_ibot = cfg.ibot.loss_weight > 0 + self.ibot_separate_head = cfg.ibot.separate_head + + logger.info("OPTIONS -- DINO") + if self.do_dino: + logger.info(f"OPTIONS -- DINO -- loss_weight: {cfg.dino.loss_weight}") + logger.info(f"OPTIONS -- DINO -- head_n_prototypes: {cfg.dino.head_n_prototypes}") + logger.info(f"OPTIONS -- DINO -- head_bottleneck_dim: {cfg.dino.head_bottleneck_dim}") + logger.info(f"OPTIONS -- DINO -- head_hidden_dim: 
{cfg.dino.head_hidden_dim}") + self.dino_loss_weight = cfg.dino.loss_weight + dino_head = partial( + DINOHead, + in_dim=embed_dim, + out_dim=cfg.dino.head_n_prototypes, + hidden_dim=cfg.dino.head_hidden_dim, + bottleneck_dim=cfg.dino.head_bottleneck_dim, + nlayers=cfg.dino.head_nlayers, + ) + self.dino_loss = DINOLoss(self.dino_out_dim) + if self.do_koleo: + logger.info("OPTIONS -- DINO -- applying KOLEO regularization") + self.koleo_loss = KoLeoLoss() + + else: + logger.info("OPTIONS -- DINO -- not using DINO") + + if self.do_dino or self.do_ibot: + student_model_dict["dino_head"] = dino_head() + teacher_model_dict["dino_head"] = dino_head() + + logger.info("OPTIONS -- IBOT") + logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") + logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_ratio_tuple: {cfg.ibot.mask_ratio_min_max}") + logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_sample_probability: {cfg.ibot.mask_sample_probability}") + if self.do_ibot: + self.ibot_loss_weight = cfg.ibot.loss_weight + assert max(cfg.ibot.mask_ratio_min_max) > 0, "please provide a positive mask ratio tuple for ibot" + assert cfg.ibot.mask_sample_probability > 0, "please provide a positive mask probability for ibot" + self.ibot_out_dim = cfg.ibot.head_n_prototypes if self.ibot_separate_head else cfg.dino.head_n_prototypes + self.ibot_patch_loss = iBOTPatchLoss(self.ibot_out_dim) + if self.ibot_separate_head: + logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") + logger.info(f"OPTIONS -- IBOT -- head_n_prototypes: {cfg.ibot.head_n_prototypes}") + logger.info(f"OPTIONS -- IBOT -- head_bottleneck_dim: {cfg.ibot.head_bottleneck_dim}") + logger.info(f"OPTIONS -- IBOT -- head_hidden_dim: {cfg.ibot.head_hidden_dim}") + ibot_head = partial( + DINOHead, + in_dim=embed_dim, + out_dim=cfg.ibot.head_n_prototypes, + hidden_dim=cfg.ibot.head_hidden_dim, + bottleneck_dim=cfg.ibot.head_bottleneck_dim, + nlayers=cfg.ibot.head_nlayers, + ) + student_model_dict["ibot_head"] = ibot_head() + teacher_model_dict["ibot_head"] = ibot_head() + else: + logger.info("OPTIONS -- IBOT -- head shared with DINO") + + # Pyramidal Feature Distillation Setup + self.pyramid_layers = [2, 5, 8] # Layers 3, 6, 9 (0-indexed) + logger.info(f"OPTIONS -- PYRAMID -- layers: {self.pyramid_layers}") + + # Create projectors for student intermediate layers + # Simple MLP: Linear -> GELU -> Linear (preserving dim) + self.student_projectors = nn.ModuleList([ + Mlp(in_features=embed_dim, hidden_features=embed_dim, out_features=embed_dim, act_layer=nn.GELU) + for _ in self.pyramid_layers + ]) + student_model_dict["projectors"] = self.student_projectors + + # Add projectors to student model dict + + # Register hooks for intermediate layers + self.student_intermediates = {} + self.teacher_intermediates = {} + + def get_activation(name, storage): + def hook(model, input, output): + # output might be a list (NestedTensorBlock) or tensor + # We store it in the storage dict + # If it's a list, we might need to process it later or just store the list + storage[name] = output + return hook + + # Register hooks on student backbone blocks + # Note: self.student.backbone.blocks is a ModuleList + for i in self.pyramid_layers: + student_model_dict["backbone"].blocks[i].register_forward_hook(get_activation(i, self.student_intermediates)) + + # Register hooks on teacher backbone blocks + for i in self.pyramid_layers: + teacher_model_dict["backbone"].blocks[i].register_forward_hook(get_activation(i, self.teacher_intermediates)) + + 
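+        # The hooks fill these dicts with each selected block's raw output, keyed by block index;
+        # forward_backward() clears both dicts at the start of every step and reads them again
+        # after the backbone forward passes, so no extra forward is needed for intermediate features.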
self.need_to_synchronize_fsdp_streams = True + + self.student = nn.ModuleDict(student_model_dict) + self.teacher = nn.ModuleDict(teacher_model_dict) + + # there is no backpropagation through the teacher, so no need for gradients + for p in self.teacher.parameters(): + p.requires_grad = False + logger.info(f"Student and Teacher are built: they are both {cfg.student.arch} network.") + + def forward(self, inputs): + raise NotImplementedError + + def backprop_loss(self, loss): + if self.fp16_scaler is not None: + self.fp16_scaler.scale(loss).backward() + else: + loss.backward() + + def forward_backward(self, images, teacher_temp): + # Clear intermediates before forward pass + self.student_intermediates.clear() + self.teacher_intermediates.clear() + + n_global_crops = 2 + assert n_global_crops == 2 + n_local_crops = self.cfg.crops.local_crops_number + + global_crops = images["collated_global_crops"].cuda(non_blocking=True) + local_crops = images["collated_local_crops"].cuda(non_blocking=True) + + masks = images["collated_masks"].cuda(non_blocking=True) + mask_indices_list = images["mask_indices_list"].cuda(non_blocking=True) + n_masked_patches_tensor = images["n_masked_patches"].cuda(non_blocking=True) + n_masked_patches = mask_indices_list.shape[0] + upperbound = images["upperbound"] + masks_weight = images["masks_weight"].cuda(non_blocking=True) + + n_local_crops_loss_terms = max(n_local_crops * n_global_crops, 1) + n_global_crops_loss_terms = (n_global_crops - 1) * n_global_crops + + do_dino = self.do_dino + do_ibot = self.do_ibot + + # loss scales + ibot_loss_scale = 1.0 / n_global_crops + + # teacher output + @torch.no_grad() + def get_teacher_output(): + x, n_global_crops_teacher = global_crops, n_global_crops + + # No output_indices needed, hooks will capture + teacher_backbone_output_dict = self.teacher.backbone(x, is_training=True) + + teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"] + teacher_cls_tokens = teacher_cls_tokens.chunk(n_global_crops_teacher) + # watch out: these are chunked and cat'd in reverse so A is matched to B in the global crops dino loss + teacher_cls_tokens = torch.cat((teacher_cls_tokens[1], teacher_cls_tokens[0])) + ibot_teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"] + _dim = ibot_teacher_patch_tokens.shape[-1] + n_cls_tokens = teacher_cls_tokens.shape[0] + + if do_ibot and not self.ibot_separate_head: + buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound + n_cls_tokens, _dim) + buffer_tensor_teacher[:n_cls_tokens].copy_(teacher_cls_tokens) + torch.index_select( + ibot_teacher_patch_tokens.flatten(0, 1), + dim=0, + index=mask_indices_list, + out=buffer_tensor_teacher[n_cls_tokens : n_cls_tokens + n_masked_patches], + ) + tokens_after_head = self.teacher.dino_head(buffer_tensor_teacher) + teacher_cls_tokens_after_head = tokens_after_head[:n_cls_tokens] + masked_teacher_patch_tokens_after_head = tokens_after_head[ + n_cls_tokens : n_cls_tokens + n_masked_patches + ] + elif do_ibot and self.ibot_separate_head: + buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound, _dim) + torch.index_select( + ibot_teacher_patch_tokens.flatten(0, 1), + dim=0, + index=mask_indices_list, + out=buffer_tensor_teacher[:n_masked_patches], + ) + teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens) + masked_teacher_patch_tokens_after_head = self.teacher.ibot_head(buffer_tensor_teacher)[ + :n_masked_patches + ] + else: + teacher_cls_tokens_after_head = 
self.teacher.dino_head(teacher_cls_tokens) + masked_teacher_ibot_softmaxed_centered = None + + if self.cfg.train.centering == "centering": + teacher_dino_softmaxed_centered_list = self.dino_loss.softmax_center_teacher( + teacher_cls_tokens_after_head, teacher_temp=teacher_temp + ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:]) + self.dino_loss.update_center(teacher_cls_tokens_after_head) + if do_ibot: + masked_teacher_patch_tokens_after_head = masked_teacher_patch_tokens_after_head.unsqueeze(0) + masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.softmax_center_teacher( + masked_teacher_patch_tokens_after_head[:, :n_masked_patches], teacher_temp=teacher_temp + ) + masked_teacher_ibot_softmaxed_centered = masked_teacher_ibot_softmaxed_centered.squeeze(0) + self.ibot_patch_loss.update_center(masked_teacher_patch_tokens_after_head[:n_masked_patches]) + + elif self.cfg.train.centering == "sinkhorn_knopp": + teacher_dino_softmaxed_centered_list = self.dino_loss.sinkhorn_knopp_teacher( + teacher_cls_tokens_after_head, teacher_temp=teacher_temp + ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:]) + + if do_ibot: + masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.sinkhorn_knopp_teacher( + masked_teacher_patch_tokens_after_head, + teacher_temp=teacher_temp, + n_masked_patches_tensor=n_masked_patches_tensor, + ) + + else: + raise NotImplementedError + + # Return intermediates from hooks (clone them to be safe) + # Teacher intermediates are tensors (since input was tensor) + # We clone to detach from graph (though no_grad already does that) and ensure safety + teacher_intermediates_cloned = {k: v.clone() for k, v in self.teacher_intermediates.items()} + return teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered, teacher_intermediates_cloned + + teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered, teacher_intermediate = get_teacher_output() + reshard_fsdp_model(self.teacher) + + loss_dict = {} + + loss_accumulator = 0 # for backprop + + # No output_indices needed, hooks will capture + student_global_backbone_output_dict, student_local_backbone_output_dict = self.student.backbone( + [global_crops, local_crops], masks=[masks, None], is_training=True + ) + + inputs_for_student_head_list = [] + + # 1a: local crops cls tokens + student_local_cls_tokens = student_local_backbone_output_dict["x_norm_clstoken"] + inputs_for_student_head_list.append(student_local_cls_tokens.unsqueeze(0)) + + # 1b: global crops cls tokens + student_global_cls_tokens = student_global_backbone_output_dict["x_norm_clstoken"] + inputs_for_student_head_list.append(student_global_cls_tokens.unsqueeze(0)) + + # 1c: global crops patch tokens + if do_ibot: + _dim = student_global_backbone_output_dict["x_norm_clstoken"].shape[-1] + ibot_student_patch_tokens = student_global_backbone_output_dict["x_norm_patchtokens"] + buffer_tensor_patch_tokens = ibot_student_patch_tokens.new_zeros(upperbound, _dim) + buffer_tensor_patch_tokens[:n_masked_patches].copy_( + torch.index_select(ibot_student_patch_tokens.flatten(0, 1), dim=0, index=mask_indices_list) + ) + if not self.ibot_separate_head: + inputs_for_student_head_list.append(buffer_tensor_patch_tokens.unsqueeze(0)) + else: + student_global_masked_patch_tokens_after_head = self.student.ibot_head(buffer_tensor_patch_tokens)[ + :n_masked_patches + ] + + # 2: run + _attn_bias, cat_inputs = fmha.BlockDiagonalMask.from_tensor_list(inputs_for_student_head_list) + 
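+        # from_tensor_list packs the local-cls, global-cls and masked-patch token sets into one
+        # tensor so a single dino_head forward covers them; the returned mask is used only by
+        # .split() below to recover the per-source outputs.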
outputs_list = _attn_bias.split(self.student.dino_head(cat_inputs)) + + # 3a: local crops cls tokens + student_local_cls_tokens_after_head = outputs_list.pop(0).squeeze(0) + + # 3b: global crops cls tokens + student_global_cls_tokens_after_head = outputs_list.pop(0).squeeze(0) + + # 3c: global crops patch tokens + if do_ibot and not self.ibot_separate_head: + student_global_masked_patch_tokens_after_head = outputs_list.pop(0).squeeze(0)[:n_masked_patches] + + if n_local_crops > 0: + dino_local_crops_loss = self.dino_loss( + student_output_list=student_local_cls_tokens_after_head.chunk(n_local_crops), + teacher_out_softmaxed_centered_list=teacher_dino_softmaxed_centered_list, + ) / (n_global_crops_loss_terms + n_local_crops_loss_terms) + + # store for display + loss_dict["dino_local_crops_loss"] = dino_local_crops_loss + + # accumulate loss + loss_accumulator += self.dino_loss_weight * dino_local_crops_loss + + # process global crops + loss_scales = 2 # this is here since we process global crops together + + if do_dino: + # compute loss + dino_global_crops_loss = ( + self.dino_loss( + student_output_list=[student_global_cls_tokens_after_head], + teacher_out_softmaxed_centered_list=[ + teacher_dino_softmaxed_centered_list.flatten(0, 1) + ], # these were chunked and stacked in reverse so A is matched to B + ) + * loss_scales + / (n_global_crops_loss_terms + n_local_crops_loss_terms) + ) + + loss_dict["dino_global_crops_loss"] = dino_global_crops_loss + + # accumulate loss + loss_accumulator += self.dino_loss_weight * dino_global_crops_loss + + student_cls_tokens = student_global_cls_tokens + + if self.do_koleo: + koleo_loss = self.cfg.dino.koleo_loss_weight * sum( + self.koleo_loss(p) for p in student_cls_tokens.chunk(2) + ) # we don't apply koleo loss between cls tokens of a same image + loss_accumulator += koleo_loss + loss_dict["koleo_loss"] = ( + koleo_loss / loss_scales + ) # this is to display the same losses as before but we can remove eventually + + if do_ibot: + # compute loss + ibot_patch_loss = ( + self.ibot_patch_loss.forward_masked( + student_global_masked_patch_tokens_after_head, + masked_teacher_ibot_softmaxed_centered, + student_masks_flat=masks, + n_masked_patches=n_masked_patches, + masks_weight=masks_weight, + ) + * loss_scales + * ibot_loss_scale + ) + + # store for display + loss_dict["ibot_loss"] = ibot_patch_loss / 2 + + # accumulate loss + loss_accumulator += self.ibot_loss_weight * ibot_patch_loss + + # --- Pyramidal Feature Distillation Loss --- + pyramid_loss = 0.0 + + # Use captured intermediates + # Student intermediates might be lists (global + local crops) + # We only want global crops for distillation (usually first 2 items in the list if it's a list) + # But wait, NestedTensorBlock returns a list of tensors corresponding to input list. + # Student input was [global_crops, local_crops]. + # So output is [global_crops_out, local_crops_out]. + # Global crops are the first element (which is itself a tensor of batch size 2*B if collated? No, global_crops is one tensor). + # Actually, input to student.backbone was [global_crops, local_crops]. + # So output of block is [global_crops_feat, local_crops_feat]. + # We want global_crops_feat. 
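+        # The teacher, by contrast, saw only the global-crop tensor, so teacher_intermediate[layer_idx]
+        # is already a plain tensor and should match s_out[0] in shape.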
+ + if self.student_intermediates and teacher_intermediate: + for i, layer_idx in enumerate(self.pyramid_layers): + # Get student features + s_out = self.student_intermediates[layer_idx] + if isinstance(s_out, list): + s_feat = s_out[0] # Global crops features + else: + s_feat = s_out # Should not happen given input structure but safe to handle + + # Project student features + s_feat_proj = self.student["projectors"][i](s_feat) + + # Get teacher features (frozen) + t_feat = teacher_intermediate[layer_idx] + + # Compute MSE loss (normalized by feature dim) + layer_loss = F.mse_loss(s_feat_proj, t_feat) + pyramid_loss += layer_loss + + # Add to total loss with weight lambda=1.0 (can be configurable) + pyramid_loss_weight = 1.0 + loss_dict["pyramid_loss"] = pyramid_loss + loss_accumulator += pyramid_loss_weight * pyramid_loss + + self.backprop_loss(loss_accumulator) + + self.fsdp_synchronize_streams() + + return loss_dict + + def fsdp_synchronize_streams(self): + if self.need_to_synchronize_fsdp_streams: + torch.cuda.synchronize() + try: + self.student.dino_head._streams = ( + self.teacher.dino_head._streams + ) = self.student.backbone._streams = self.teacher.backbone._streams + except AttributeError: + pass + self.need_to_synchronize_fsdp_streams = False + + def update_teacher(self, m): + student_param_list = [] + teacher_param_list = [] + with torch.no_grad(): + for k in self.student.keys(): + if k == "projectors": + continue + for ms, mt in zip(get_fsdp_modules(self.student[k]), get_fsdp_modules(self.teacher[k])): + student_param_list += ms.params + teacher_param_list += mt.params + torch._foreach_mul_(teacher_param_list, m) + torch._foreach_add_(teacher_param_list, student_param_list, alpha=1 - m) + + def train(self): + super().train() + self.teacher.eval() + + def get_maybe_fused_params_for_submodel(self, m): + params_groups = get_params_groups_with_decay( + model=m, + lr_decay_rate=self.cfg.optim.layerwise_decay, + patch_embed_lr_mult=self.cfg.optim.patch_embed_lr_mult, + ) + fused_params_groups = fuse_params_groups(params_groups) + logger.info("fusing param groups") + + for g in fused_params_groups: + g["foreach"] = True + return fused_params_groups + + def get_params_groups(self): + all_params_groups = [] + for m in self.student.values(): + all_params_groups += self.get_maybe_fused_params_for_submodel(m) + return all_params_groups + + def prepare_for_distributed_training(self): + logger.info("DISTRIBUTED FSDP -- preparing model for distributed training") + if has_batchnorms(self.student): + raise NotImplementedError + # below will synchronize all student subnetworks across gpus: + for k, v in self.student.items(): + # For projectors, we need to initialize teacher projectors? + # Actually, teacher doesn't have projectors in this design (it's distillation FROM teacher backbone TO student backbone+projector) + # But we added 'projectors' to student_model_dict. + # We should NOT add them to teacher_model_dict if teacher doesn't use them. + # In __init__, we did NOT add 'projectors' to teacher_model_dict. + # So self.teacher[k] will fail if k='projectors'. + + if k == "projectors": + # Projectors are only on student. 
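+                # There is no teacher counterpart to sync from, so skip the teacher state_dict copy
+                # below and wrap each projector MLP in its own FSDP unit instead.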
+ student_model_cfg = self.cfg.compute_precision.student["backbone"] # Use backbone config for now + # Wrap each MLP individually so we can call them independently + wrapped_projectors = [] + for projector in self.student[k]: + wrapped_projectors.append(get_fsdp_wrapper(student_model_cfg)(projector)) + self.student[k] = nn.ModuleList(wrapped_projectors) + continue + + self.teacher[k].load_state_dict(self.student[k].state_dict()) + student_model_cfg = self.cfg.compute_precision.student[k] + self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k]) + teacher_model_cfg = self.cfg.compute_precision.teacher[k] + self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k]) diff --git a/dinov2/train/train.py b/dinov2/train/train.py index 473b8d014..ecd034bfb 100644 --- a/dinov2/train/train.py +++ b/dinov2/train/train.py @@ -8,6 +8,8 @@ import math import os from functools import partial +import wandb +from omegaconf import OmegaConf from fvcore.common.checkpoint import PeriodicCheckpointer import torch @@ -21,6 +23,8 @@ from dinov2.utils.utils import CosineScheduler from dinov2.train.ssl_meta_arch import SSLMetaArch +from PIL import Image # <-- add this line +from dinov2.train.ssl_meta_arch_pyramid import PyramidSSLMetaArch torch.backends.cuda.matmul.allow_tf32 = True # PyTorch 1.12 sets this to False by default @@ -54,6 +58,16 @@ def get_args_parser(add_help: bool = True): type=str, help="Output directory to save logs and checkpoints", ) + parser.add_argument("--enable-wandb", action="store_true", help="Enable WandB logging") + parser.add_argument("--wandb-project", type=str, default="dinov2", help="WandB project name") + parser.add_argument( + "--wandb-entity", + type=str, + default="sd6701-new-york-university", # <- was "dinov2-traning-1" + help="WandB entity (team or username). 
If None, use your default account.", + ) + parser.add_argument("--wandb-name", type=str, default="dinov2-traning-sam", help="WandB run name") + parser.add_argument("--wandb-api-key", type=str, default="14abcf8b33d9a7f066dd1988891a00fec55f4030", help="WandB API key") return parser @@ -131,11 +145,22 @@ def do_test(cfg, model, iteration): torch.save({"teacher": new_state_dict}, teacher_ckp_path) -def do_train(cfg, model, resume=False): +def do_train(cfg, model, args, resume=False): model.train() inputs_dtype = torch.half fp16_scaler = model.fp16_scaler # for mixed precision training + if args.enable_wandb and distributed.is_main_process(): + if args.wandb_api_key: + wandb.login(key=args.wandb_api_key) + wandb.init( + project=args.wandb_project, + entity=args.wandb_entity, + name=args.wandb_name, + config=OmegaConf.to_container(cfg, resolve=True), + dir=args.output_dir, + ) + # setup optimizer optimizer = build_optimizer(cfg, model.get_params_groups()) @@ -161,6 +186,13 @@ def do_train(cfg, model, resume=False): max_iter=max_iter, max_to_keep=3, ) + # periodic_checkpointer = PeriodicCheckpointer( + # checkpointer, + # # save once per "epoch" + # period=OFFICIAL_EPOCH_LENGTH, + # max_iter=max_iter, + # max_to_keep=3, + # ) # setup data preprocessing @@ -172,14 +204,23 @@ def do_train(cfg, model, resume=False): max_num_patches=0.5 * img_size // patch_size * img_size // patch_size, ) - data_transform = DataAugmentationDINO( + base_transform = DataAugmentationDINO( cfg.crops.global_crops_scale, cfg.crops.local_crops_scale, cfg.crops.local_crops_number, - global_crops_size=cfg.crops.global_crops_size, - local_crops_size=cfg.crops.local_crops_size, + global_crops_size=cfg.crops.global_crops_size, # 224 + local_crops_size=cfg.crops.local_crops_size, # 112 ) + # Wrapper to first resize raw CC3M image (96x96) to 224x224 + def data_transform(img): + # img should be a PIL.Image from CC3MDataset + if not isinstance(img, Image.Image): + # just in case, convert tensor/array to PIL + img = Image.fromarray(np.array(img)) + + return base_transform(img) + collate_fn = partial( collate_data_and_cast, mask_ratio_tuple=cfg.ibot.mask_ratio_min_max, @@ -282,6 +323,18 @@ def do_train(cfg, model, resume=False): metric_logger.update(current_batch_size=current_batch_size) metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced) + if args.enable_wandb and distributed.is_main_process(): + wandb.log({ + "train/lr": lr, + "train/wd": wd, + "train/mom": mom, + "train/last_layer_lr": last_layer_lr, + "train/current_batch_size": current_batch_size, + "train/total_loss": losses_reduced, + **{f"train/{k}": v for k, v in loss_dict_reduced.items()}, + "train/iteration": iteration, + }) + # checkpointing and testing if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0: @@ -297,7 +350,7 @@ def do_train(cfg, model, resume=False): def main(args): cfg = setup(args) - model = SSLMetaArch(cfg).to(torch.device("cuda")) + model = PyramidSSLMetaArch(cfg).to(torch.device("cuda")) model.prepare_for_distributed_training() logger.info("Model:\n{}".format(model)) @@ -310,7 +363,7 @@ def main(args): ) return do_test(cfg, model, f"manual_{iteration}") - do_train(cfg, model, resume=not args.no_resume) + do_train(cfg, model, args, resume=not args.no_resume) if __name__ == "__main__": diff --git a/dinov2/train/train_pyramid.py b/dinov2/train/train_pyramid.py new file mode 100644 index 000000000..3a653ad7e --- /dev/null +++ b/dinov2/train/train_pyramid.py @@ -0,0 +1,353 @@ +# 
Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import argparse +import logging +import math +import os +from functools import partial + +from fvcore.common.checkpoint import PeriodicCheckpointer +from omegaconf import OmegaConf +import torch +import wandb + +from dinov2.data import SamplerType, make_data_loader, make_dataset +from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator +import dinov2.distributed as distributed +from dinov2.fsdp import FSDPCheckpointer +from dinov2.logging import MetricLogger +from dinov2.utils.config import setup +from dinov2.utils.utils import CosineScheduler + +# Import PyramidSSLMetaArch instead of SSLMetaArch +from dinov2.train.ssl_meta_arch_pyramid import PyramidSSLMetaArch + + +torch.backends.cuda.matmul.allow_tf32 = True # PyTorch 1.12 sets this to False by default +logger = logging.getLogger("dinov2") + + +def get_args_parser(add_help=True): + parser = argparse.ArgumentParser("DINOv2 training", add_help=add_help) + parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file") + parser.add_argument( + "--no-resume", + action="store_true", + help="Whether to not attempt to resume from the checkpoint directory. ", + ) + parser.add_argument("--eval-only", action="store_true", help="perform evaluation only") + parser.add_argument("--eval", type=str, default="", help="Eval type to perform") + parser.add_argument( + "opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument("--output-dir", default="", type=str, help="Output directory to save logs and checkpoints") + parser.add_argument( + "--run_name", + type=str, + help="Name for the wandb run", + default="pyramid_distillation_4gpu", # Default run name + ) + return parser + + +def build_optimizer(cfg, params_groups): + return torch.optim.AdamW(params_groups, betas=(cfg.optim.adamw_beta1, cfg.optim.adamw_beta2)) + + +def build_schedulers(cfg): + OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH + lr = dict( + base_value=cfg.optim.lr, + final_value=cfg.optim.min_lr, + total_iters=cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH, + warmup_iters=cfg.optim.warmup_epochs * OFFICIAL_EPOCH_LENGTH, + start_warmup_value=0, + ) + wd = dict( + base_value=cfg.optim.weight_decay, + final_value=cfg.optim.weight_decay_end, + total_iters=cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH, + ) + momentum = dict( + base_value=cfg.teacher.momentum_teacher, + final_value=cfg.teacher.final_momentum_teacher, + total_iters=cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH, + ) + teacher_temp = dict( + base_value=cfg.teacher.warmup_teacher_temp, + final_value=cfg.teacher.teacher_temp, + total_iters=cfg.teacher.warmup_teacher_temp_epochs * OFFICIAL_EPOCH_LENGTH, + warmup_iters=cfg.teacher.warmup_teacher_temp_epochs * OFFICIAL_EPOCH_LENGTH, + start_warmup_value=cfg.teacher.warmup_teacher_temp, + ) + + lr_schedule = CosineScheduler(**lr) + wd_schedule = CosineScheduler(**wd) + momentum_schedule = CosineScheduler(**momentum) + teacher_temp_schedule = CosineScheduler(**teacher_temp) + last_layer_lr_schedule = CosineScheduler(**lr) + + last_layer_lr_schedule.schedule[ + : cfg.optim.freeze_last_layer_epochs * OFFICIAL_EPOCH_LENGTH + ] = 0 # mimicking the original schedules + + logger.info("Schedulers ready.") + + return ( + lr_schedule, + wd_schedule, + momentum_schedule, + 
teacher_temp_schedule, + last_layer_lr_schedule, + ) + + +def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr): + for param_group in optimizer.param_groups: + is_last_layer = param_group["is_last_layer"] + lr_multiplier = param_group["lr_multiplier"] + wd_multiplier = param_group["wd_multiplier"] + param_group["weight_decay"] = wd * wd_multiplier + param_group["lr"] = (last_layer_lr if is_last_layer else lr) * lr_multiplier + + +def do_test(cfg, model, iteration): + new_state_dict = model.teacher.state_dict() + + if distributed.is_main_process(): + iter_string = str(iteration) + eval_dir = os.path.join(cfg.train.output_dir, "eval", iter_string) + os.makedirs(eval_dir, exist_ok=True) + # save teacher checkpoint + teacher_ckp_path = os.path.join(eval_dir, "teacher_checkpoint.pth") + torch.save({"teacher": new_state_dict}, teacher_ckp_path) + + +def do_train(cfg, model, resume=False): + model.train() + inputs_dtype = torch.half + fp16_scaler = model.fp16_scaler # for mixed precision training + + # setup optimizer + + optimizer = build_optimizer(cfg, model.get_params_groups()) + ( + lr_schedule, + wd_schedule, + momentum_schedule, + teacher_temp_schedule, + last_layer_lr_schedule, + ) = build_schedulers(cfg) + + # checkpointer + checkpointer = FSDPCheckpointer(model, cfg.train.output_dir, optimizer=optimizer, save_to_disk=True) + + start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1 + + OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH + max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH + + periodic_checkpointer = PeriodicCheckpointer( + checkpointer, + period=cfg.train.saveckp_freq * OFFICIAL_EPOCH_LENGTH, + max_iter=max_iter, + max_to_keep=3, + ) + + # setup data preprocessing + + img_size = cfg.crops.global_crops_size + patch_size = cfg.student.patch_size + n_tokens = (img_size // patch_size) ** 2 + mask_generator = MaskingGenerator( + input_size=(img_size // patch_size, img_size // patch_size), + max_num_patches=0.5 * img_size // patch_size * img_size // patch_size, + ) + + data_transform = DataAugmentationDINO( + cfg.crops.global_crops_scale, + cfg.crops.local_crops_scale, + cfg.crops.local_crops_number, + global_crops_size=cfg.crops.global_crops_size, + local_crops_size=cfg.crops.local_crops_size, + ) + + collate_fn = partial( + collate_data_and_cast, + mask_ratio_tuple=cfg.ibot.mask_ratio_min_max, + mask_probability=cfg.ibot.mask_sample_probability, + n_tokens=n_tokens, + mask_generator=mask_generator, + dtype=inputs_dtype, + ) + + # setup data loader + + dataset = make_dataset( + dataset_str=cfg.train.dataset_path, + transform=data_transform, + target_transform=lambda _: (), + ) + # sampler_type = SamplerType.INFINITE + sampler_type = SamplerType.SHARDED_INFINITE + data_loader = make_data_loader( + dataset=dataset, + batch_size=cfg.train.batch_size_per_gpu, + num_workers=cfg.train.num_workers, + shuffle=True, + seed=start_iter, # TODO: Fix this -- cfg.train.seed + sampler_type=sampler_type, + sampler_advance=0, # TODO(qas): fix this -- start_iter * cfg.train.batch_size_per_gpu, + drop_last=True, + collate_fn=collate_fn, + ) + + # training loop + + iteration = start_iter + + logger.info("Starting training from iteration {}".format(start_iter)) + metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json") + metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file) + header = "Training" + + for data in metric_logger.log_every( + data_loader, + 10, + header, + max_iter, + start_iter, + ): + 
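+        # collated_global_crops stacks both global crops along the batch dimension,
+        # so dividing its first dimension by 2 gives the per-GPU image count.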
current_batch_size = data["collated_global_crops"].shape[0] / 2 + if iteration > max_iter: + return + + # apply schedules + + lr = lr_schedule[iteration] + wd = wd_schedule[iteration] + mom = momentum_schedule[iteration] + teacher_temp = teacher_temp_schedule[iteration] + last_layer_lr = last_layer_lr_schedule[iteration] + apply_optim_scheduler(optimizer, lr, wd, last_layer_lr) + + # compute losses + + optimizer.zero_grad(set_to_none=True) + loss_dict = model.forward_backward(data, teacher_temp=teacher_temp) + + # clip gradients + grad_norm = 0.0 + if fp16_scaler is not None: + if cfg.optim.clip_grad: + fp16_scaler.unscale_(optimizer) + for v in model.student.values(): + if isinstance(v, torch.nn.ModuleList): + for sub_v in v: + g_norm = sub_v.clip_grad_norm_(cfg.optim.clip_grad) + grad_norm += g_norm.item() ** 2 + else: + g_norm = v.clip_grad_norm_(cfg.optim.clip_grad) + grad_norm += g_norm.item() ** 2 + grad_norm = grad_norm ** 0.5 + fp16_scaler.step(optimizer) + fp16_scaler.update() + else: + if cfg.optim.clip_grad: + for v in model.student.values(): + if isinstance(v, torch.nn.ModuleList): + for sub_v in v: + g_norm = sub_v.clip_grad_norm_(cfg.optim.clip_grad) + grad_norm += g_norm.item() ** 2 + else: + g_norm = v.clip_grad_norm_(cfg.optim.clip_grad) + grad_norm += g_norm.item() ** 2 + grad_norm = grad_norm ** 0.5 + optimizer.step() + + # perform teacher EMA update + + model.update_teacher(mom) + + # logging + + if distributed.get_global_size() > 1: + for v in loss_dict.values(): + torch.distributed.all_reduce(v) + loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()} + + if math.isnan(sum(loss_dict_reduced.values())): + logger.info("NaN detected") + raise AssertionError + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + + metric_logger.update(lr=lr) + metric_logger.update(wd=wd) + metric_logger.update(mom=mom) + metric_logger.update(last_layer_lr=last_layer_lr) + metric_logger.update(current_batch_size=current_batch_size) + metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced) + + if distributed.is_main_process(): + log_dict = { + "train/lr": lr, + "train/wd": wd, + "train/mom": mom, + "train/total_loss": losses_reduced, + "train/grad_norm": grad_norm, + "train/epoch": iteration / OFFICIAL_EPOCH_LENGTH, + "train/iteration": iteration, + } + log_dict.update({f"train/{k}": v for k, v in loss_dict_reduced.items()}) + wandb.log(log_dict) + + # checkpointing and testing + + if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0: + do_test(cfg, model, iteration) + torch.cuda.synchronize() + + periodic_checkpointer.step(iteration) + + iteration = iteration + 1 + metric_logger.synchronize_between_processes() + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +def main(args): + cfg = setup(args) + + # Use PyramidSSLMetaArch + model = PyramidSSLMetaArch(cfg).to(torch.device("cuda")) + model.prepare_for_distributed_training() + + if distributed.is_main_process(): + wandb.init( + project="dinov2_pyramid_dl", + config=OmegaConf.to_container(cfg, resolve=True), + resume=not args.no_resume, + name=args.run_name # Use run name from args + ) + + logger.info("Model:\n{}".format(model)) + if args.eval_only: + iteration = ( + FSDPCheckpointer(model, save_dir=cfg.train.output_dir) + .resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume) + .get("iteration", -1) + + 1 + ) + return do_test(cfg, model, iteration) + + do_train(cfg, model, 
resume=not args.no_resume) + + +if __name__ == "__main__": + args = get_args_parser().parse_args() + main(args) diff --git a/submit_dinov2_pyramid.sbatch b/submit_dinov2_pyramid.sbatch new file mode 100644 index 000000000..24761c2ec --- /dev/null +++ b/submit_dinov2_pyramid.sbatch @@ -0,0 +1,44 @@ +#!/bin/bash +#SBATCH --partition=a100_long +#SBATCH --gres=gpu:a100:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=160Gb +#SBATCH --time=20-00:00:00 +#SBATCH --job-name=dinov2_pyramid + + +# Load modules +module load python/3.10.10 +module load cuda/12.1 + +# Activate conda environment +source /gpfs/data/shenlab/aj4718/miniconda3/etc/profile.d/conda.sh +conda activate dinov2 + +export WANDB_API_KEY='' +export PYTHONPATH=$PYTHONPATH:. + +# Print environment info +echo "Python: $(which python)" +echo "Nodes: $SLURM_JOB_NODELIST" +echo "GPUs: $SLURM_JOB_GPUS" +echo "========================================" + +# Run the training script +CONFIG="dinov2/configs/train/vits14_pyramid.yaml" +OUTPUT_DIR="/gpfs/data/shenlab/aj4718/dinov2/logs/vits14_pyramid" + +# Create output directory +mkdir -p $OUTPUT_DIR + +# Detect number of GPUs +NUM_GPUS=$(nvidia-smi -L | wc -l) +echo "Detected $NUM_GPUS GPUs" + +# Using torchrun for distributed training +# It will automatically use the detected number of GPUs via --nproc_per_node +torchrun --nproc_per_node=$NUM_GPUS dinov2/train/train_pyramid.py \ + --config-file $CONFIG \ + --output-dir $OUTPUT_DIR \ + --run_name pyramid_distillation_4gpu
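For reference, a minimal sanity check of the new HuggingFaceDataset outside the trainer (not part of the patch; assumes the "datasets" package is installed and the tsbpp/fall2025_deeplearning repo is reachable):

    from dinov2.data.datasets import HuggingFaceDataset

    # Same dataset and split as train.dataset_path in vits14_pyramid.yaml
    # ("HuggingFace:root=tsbpp/fall2025_deeplearning:split=train")
    ds = HuggingFaceDataset(root="tsbpp/fall2025_deeplearning", split="train")
    print(len(ds))            # sample count reported by the datasets library
    image, target = ds[0]     # RGB PIL.Image plus the dummy SSL target 0
    print(image.size, target)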