diff --git a/dinov2/configs/train/vits16_hipt.yaml b/dinov2/configs/train/vits16_hipt.yaml
new file mode 100644
index 000000000..77526c05a
--- /dev/null
+++ b/dinov2/configs/train/vits16_hipt.yaml
@@ -0,0 +1,2 @@
+start_with_hipt_weights: true
+# TODO: Rest of config
diff --git a/dinov2/hub/backbones.py b/dinov2/hub/backbones.py
index 17e00981f..58248b852 100644
--- a/dinov2/hub/backbones.py
+++ b/dinov2/hub/backbones.py
@@ -5,6 +5,7 @@
 
 from enum import Enum
 from typing import Union
 
+from torch.nn.functional import pad
 import torch
 
@@ -82,3 +83,88 @@ def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Wei
     return _make_dinov2_model(
         arch_name="vit_giant2", ffn_layer="swiglufused", weights=weights, pretrained=pretrained, **kwargs
     )
+
+
+def dinov2_vits16_hipt(*, pretrained: bool = True, img_size: int = 224, use_teacher_weights: bool = False, **kwargs):
+    """
+    DINOv2 ViT-S/16 model (optionally) pretrained on TCGA.
+
+    Pretrained weights are copied from the HIPT model:
+    https://github.com/mahmoodlab/HIPT/blob/master/HIPT_4K/Checkpoints/vit256_small_dino.pth
+    """
+    if img_size < 224:
+        raise NotImplementedError('Shrinking position embeddings is not currently supported')
+
+    model = _make_dinov2_model(
+        arch_name='vit_small',
+        img_size=img_size,
+        patch_size=16,
+        pretrained=False,
+    )
+
+    if pretrained:
+        if use_teacher_weights:
+            state_dict_key = 'teacher'
+            backbone_prefix = 'backbone.'
+        else:
+            state_dict_key = 'student'
+            backbone_prefix = 'module.backbone.'
+
+        hipt_state_dict = torch.hub.load_state_dict_from_url(
+            'https://github.com/mahmoodlab/HIPT/raw/a9b5bb8d159684fc4c2c497d68950ab915caeb7e/HIPT_4K/Checkpoints/vit256_small_dino.pth',
+            map_location="cpu",
+        )
+        hipt_backbone_weights = {
+            name[len(backbone_prefix):]: params
+            for name, params in hipt_state_dict[state_dict_key].items()
+            if name.startswith(backbone_prefix)
+        }
+
+        hipt_backbone_weights['mask_token'] = model.mask_token
+
+        # Initialise layer scale (gamma) to 1
+        for i, block in enumerate(model.blocks):
+            hipt_backbone_weights[f'blocks.{i}.ls1.gamma'] = torch.ones_like(block.ls1.gamma)
+            hipt_backbone_weights[f'blocks.{i}.ls2.gamma'] = torch.ones_like(block.ls2.gamma)
+
+        # Changing the input image size for a vision transformer model is tricky.
+        #
+        # The established approach is to interpolate the position embeddings, but this is only really
+        # appropriate when a change in input image size implies a change in scale (in terms of pathology,
+        # this corresponds to a change in magnification).
+        #
+        # For us, changing the input size doesn't mean changing scale; rather, we're adding surrounding
+        # context at the same scale.
+        #
+        # I don't think there's a perfect solution to this problem, so I'll just use reflection padding and
+        # accept that some amount of further training will be required later on. A potential risk of this
+        # approach is that we are now presenting the model with multiple patches that have the same position
+        # embedding. The flipped ordering caused by reflection is a bit strange, too.
+        #
+        # Another option would be random initialisation, but that seems risky given that the model has
+        # never seen those embeddings.
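+        #
+        # Worked example (illustrative numbers, not taken from any config): if img_size were 448 with
+        # patch_size=16, the new grid would be 28x28 versus the checkpoint's 14x14 grid, so sz_diff = 14
+        # and the 14x14 grid of embeddings is reflection-padded by 7 rows/columns on each side to reach
+        # 28x28, then flattened back to shape (1, num_patches, embed_dim).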
+        pos_embed_hipt = hipt_backbone_weights.pop('pos_embed')
+        num_patches = model.patch_embed.num_patches
+        class_pos_embed = pos_embed_hipt[:, 0:1]
+        patch_pos_embed = pos_embed_hipt[:, 1:]
+        old_num_patches = patch_pos_embed.shape[1]  # patch tokens only, excluding the class token
+        old_sz = round(old_num_patches ** 0.5)
+        sz = round(num_patches ** 0.5)
+        sz_diff = sz - old_sz
+        patch_pos_embed = patch_pos_embed.view(old_sz, old_sz, -1)
+        pad_before = sz_diff // 2
+        pad_after = sz_diff - pad_before
+        new_patch_pos_embed = pad(patch_pos_embed.permute(2, 0, 1), (pad_before, pad_after, pad_before, pad_after), mode='reflect').permute(1, 2, 0)
+        hipt_backbone_weights['pos_embed'] = torch.cat([
+            class_pos_embed,
+            new_patch_pos_embed.reshape(1, num_patches, model.num_features),  # reshape: the permuted tensor is not contiguous
+        ], dim=1)
+
+        model.load_state_dict(hipt_backbone_weights)
+
+    return model
diff --git a/dinov2/train/ssl_meta_arch.py b/dinov2/train/ssl_meta_arch.py
index 3ccf15e90..c2975e572 100644
--- a/dinov2/train/ssl_meta_arch.py
+++ b/dinov2/train/ssl_meta_arch.py
@@ -37,9 +37,16 @@ def __init__(self, cfg):
         student_model_dict = dict()
         teacher_model_dict = dict()
 
-        student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg)
-        student_model_dict["backbone"] = student_backbone
-        teacher_model_dict["backbone"] = teacher_backbone
+        if cfg.start_with_hipt_weights:
+            from dinov2.hub.backbones import dinov2_vits16_hipt
+            img_size = cfg.crops.global_crops_size
+            student_backbone = dinov2_vits16_hipt(img_size=img_size, use_teacher_weights=False)
+            teacher_backbone = dinov2_vits16_hipt(img_size=img_size, use_teacher_weights=True)
+            embed_dim = student_backbone.embed_dim
+        else:
+            student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg)
+        student_model_dict["backbone"] = student_backbone
+        teacher_model_dict["backbone"] = teacher_backbone
         logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}")
 
         if cfg.student.pretrained_weights: