2 changes: 2 additions & 0 deletions dinov2/configs/train/vits16_hipt.yaml
@@ -0,0 +1,2 @@
start_with_hipt_weights: true
# TODO: Rest of config
81 changes: 81 additions & 0 deletions dinov2/hub/backbones.py
@@ -5,6 +5,7 @@

from enum import Enum
from typing import Union
from torch.nn.functional import pad

import torch

@@ -82,3 +83,83 @@ def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Wei
    return _make_dinov2_model(
        arch_name="vit_giant2", ffn_layer="swiglufused", weights=weights, pretrained=pretrained, **kwargs
    )


def dinov2_vits16_hipt(*, pretrained: bool = True, img_size: int = 224, use_teacher_weights: bool = False, **kwargs):
    """
    DINOv2 ViT-S/16 model, optionally pretrained on TCGA.

    Pretrained weights are copied from the HIPT model:
    https://github.com/mahmoodlab/HIPT/blob/master/HIPT_4K/Checkpoints/vit256_small_dino.pth
    """
    if img_size < 224:
        raise NotImplementedError('Shrinking position embeddings is not currently supported')

    model = _make_dinov2_model(
        arch_name='vit_small',
        img_size=img_size,
        patch_size=16,
        pretrained=False,
        **kwargs,
    )

    if pretrained:
        if use_teacher_weights:
            state_dict_key = 'teacher'
            backbone_prefix = 'backbone.'
        else:
            state_dict_key = 'student'
            backbone_prefix = 'module.backbone.'

        hipt_state_dict = torch.hub.load_state_dict_from_url(
            'https://github.com/mahmoodlab/HIPT/raw/a9b5bb8d159684fc4c2c497d68950ab915caeb7e/HIPT_4K/Checkpoints/vit256_small_dino.pth',
            map_location="cpu",
        )
        hipt_backbone_weights = {
            name[len(backbone_prefix):]: params
            for name, params in hipt_state_dict[state_dict_key].items()
            if name.startswith(backbone_prefix)
        }
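        # e.g. the student key 'module.backbone.cls_token' becomes 'cls_token'
        # (the extra 'module.' prefix on student keys presumably comes from
        # DistributedDataParallel wrapping during DINO training)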

        # The HIPT checkpoint has no mask token, so keep this model's freshly
        # initialised one
        hipt_backbone_weights['mask_token'] = model.mask_token

        # Initialise layer scale (gamma) to 1, which makes each LayerScale an
        # identity map; the HIPT checkpoint has no layer scale parameters
        for i, block in enumerate(model.blocks):
            hipt_backbone_weights[f'blocks.{i}.ls1.gamma'] = torch.ones_like(block.ls1.gamma)
            hipt_backbone_weights[f'blocks.{i}.ls2.gamma'] = torch.ones_like(block.ls2.gamma)

        # Changing the input image size for a vision transformer model is tricky.
        #
        # The established approach is to interpolate the position embeddings, but this is
        # only really appropriate when the change in input image size implies a change in
        # scale (in pathology terms, a change in magnification).
        #
        # For us, changing the input size doesn't mean changing scale; rather, we're adding
        # surrounding context at the same scale.
        #
        # I don't think there's a perfect solution to this problem, so I'll just use
        # reflection padding and accept that some amount of further training will be
        # required later on. A potential risk of this approach is that we are now presenting
        # the model with multiple patches that share the same position embedding. The
        # flipped ordering caused by reflection is a bit strange, too.
        #
        # Another option would be random initialisation of the new embeddings, but that
        # seems risky given that the model has never seen those values.
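        # Worked example (illustrative): with the checkpoint's 14x14 grid at
        # img_size=224, requesting img_size=256 gives sz = 16 and sz_diff = 2, so one
        # reflected row/column is added on each side (pad_before = pad_after = 1).
        # Note that 'reflect' padding requires each pad amount to be smaller than the
        # input grid size (14), which caps img_size at 640 for this checkpoint.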
        pos_embed_hipt = hipt_backbone_weights.pop('pos_embed')
        num_patches = model.patch_embed.num_patches
        class_pos_embed = pos_embed_hipt[:, 0:1]
        patch_pos_embed = pos_embed_hipt[:, 1:]
        old_num_patches = patch_pos_embed.shape[1]  # excludes the class token
        old_sz = round(old_num_patches ** 0.5)
        sz = round(num_patches ** 0.5)
        sz_diff = sz - old_sz
        patch_pos_embed = patch_pos_embed.view(old_sz, old_sz, -1)
        pad_before = sz_diff // 2
        pad_after = sz_diff - pad_before
        new_patch_pos_embed = pad(
            patch_pos_embed.permute(2, 0, 1),  # (C, H, W), as pad expects
            (pad_before, pad_after, pad_before, pad_after),
            mode='reflect',
        ).permute(1, 2, 0)
        hipt_backbone_weights['pos_embed'] = torch.cat([
            class_pos_embed,
            # reshape rather than view: the permuted tensor is non-contiguous
            new_patch_pos_embed.reshape(1, num_patches, model.num_features),
        ], dim=1)

        model.load_state_dict(hipt_backbone_weights)

    return model
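# Usage sketch (illustrative; assumes the defaults above): embed a 256x256 tile
# with the HIPT-initialised backbone.
#
#     model = dinov2_vits16_hipt(pretrained=True, img_size=256)
#     model.eval()
#     with torch.no_grad():
#         feats = model(torch.randn(1, 3, 256, 256))  # CLS features, shape (1, 384)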
13 changes: 10 additions & 3 deletions dinov2/train/ssl_meta_arch.py
@@ -37,9 +37,16 @@ def __init__(self, cfg):
        student_model_dict = dict()
        teacher_model_dict = dict()

-        student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg)
-        student_model_dict["backbone"] = student_backbone
-        teacher_model_dict["backbone"] = teacher_backbone
+        if cfg.get("start_with_hipt_weights", False):
+            from dinov2.hub.backbones import dinov2_vits16_hipt
+            img_size = cfg.crops.global_crops_size
+            student_backbone = dinov2_vits16_hipt(img_size=img_size, use_teacher_weights=False)
+            teacher_backbone = dinov2_vits16_hipt(img_size=img_size, use_teacher_weights=True)
+            embed_dim = student_backbone.embed_dim
+        else:
+            student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg)
+        student_model_dict["backbone"] = student_backbone
+        teacher_model_dict["backbone"] = teacher_backbone
logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}")

if cfg.student.pretrained_weights: