Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions dinov2/configs/train/vitg14_reg4_highres.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# High-resolution ViT-g/14 (4 register tokens) DINOv2 training config.
dino:
  head_n_prototypes: 131072
  head_bottleneck_dim: 384
  do_kde: True
  kde_loss_weight: .05
  # KoLeo is disabled both by weight and by flag.
  koleo_loss_weight: 0
  do_koleo: False
ibot:
  loss_weight: 1.0
  mask_sample_probability: 0.5
  mask_ratio_min_max:
  - 0.1
  - 0.45
  separate_head: true
  head_n_prototypes: 131072
train:
  sample_list_path: sample_dataset_highres_1M.txt # magnification-aware: [1, 0.5, 0.25, 0.125] µm/px at 512px
  streaming_from_hf: false
  # Only used when streaming_from_hf is true.
  streaming_dataset_path: medarc/TCGA-12K-parquet
  batch_size_per_gpu: 6
  centering: sinkhorn_knopp
  use_pretrained: False
  # When set, both student and teacher are initialized from this teacher
  # checkpoint (takes precedence over use_pretrained in train.py main()).
  teacher_checkpoint_path: /data/OpenMidnight_ckpts/OM_replication_interpolationfix/eval/training_45000/teacher_checkpoint.pth
  OFFICIAL_EPOCH_LENGTH: 1250
  num_workers: 24
  prefetch_factor: 8
  skip_checkpointer: false
  # Effective batch = batch_size_per_gpu * world_size * gradient_accumulation_steps
  # (the sqrt LR scaling rule in utils/config.py accounts for this factor).
  gradient_accumulation_steps: 4
  # On-slide read size for SlideDataset patches; crops are taken from these.
  patch_size_pixels: 448
student:
  arch: vit_giant2
  patch_size: 14
  drop_path_rate: 0.4
  ffn_layer: swiglufused
  block_chunks: 4
  num_register_tokens: 4
teacher:
  momentum_teacher: 0.994
optim:
  epochs: 96 # 120k iterations / 1250 steps_per_epoch = 96 epochs
  early_stop: 96
  weight_decay_end: 0.2
  base_lr: 1.0e-04
  warmup_epochs: 2
  # 1.0 disables layer-wise LR decay.
  layerwise_decay: 1.0
crops:
  global_crops_scale:
  - 0.32
  - 1.0
  local_crops_number: 8
  local_crops_scale:
  - 0.05
  - 0.32
  global_crops_size: 392
  local_crops_size: 168
evaluation:
  eval_period_iterations: 5000
  bach_root: /block/eva-data/bach
  breakhis_root: /block/eva-data/breakhis
  pcam_root: /block/eva-data/patch_camelyon
36 changes: 17 additions & 19 deletions dinov2/data/datasets/slide_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,20 @@
from openslide import OpenSlide
import numpy as np
import cv2
from PIL import Image

class SlideDataset(ExtendedVisionDataset):
def __init__(self, root, sample_list_path, *args, patch_size_pixels=224, **kwargs) -> None:
    """Dataset of whole-slide-image patches enumerated by a sample list file.

    Args:
        root: Dataset root, forwarded to ExtendedVisionDataset.
        sample_list_path: Text file with one sample per line
            ("path x y level [read_size]"); blank lines are ignored.
        patch_size_pixels: Side length in pixels of the patches returned
            by __getitem__; on-slide reads at other sizes are resized to
            this. Keyword-only with the historical default of 224, so
            existing callers are unaffected.

    Raises:
        FileNotFoundError: If sample_list_path does not exist.
    """
    super().__init__(root, *args, **kwargs)
    self.sample_list_path = Path(sample_list_path)
    self.patch_size_pixels = patch_size_pixels
    if not self.sample_list_path.is_file():
        raise FileNotFoundError(f"Sample list not found at {self.sample_list_path}")

    # Keep only non-empty lines; each entry stays a raw string and is
    # parsed lazily in __getitem__ / get_all.
    with self.sample_list_path.open("r") as f:
        self.image_files = [line.strip() for line in f if line.strip()]

    print(f"This many resolved paths {len(self.image_files)} from {self.sample_list_path} (patch_size={patch_size_pixels})")

def get_all(self, index):
parts = self.image_files[index].split(" ")
Expand All @@ -28,33 +30,29 @@ def get_all(self, index):
return image, path

def __getitem__(self, index: int) -> Tuple[Any, Any]:
    """Load one RGB patch described by a whitespace-separated sample line.

    Each sample line is "path x y level [read_size]". The optional fifth
    field requests a different on-slide read size (magnification-aware
    sampling); the result is resized to self.patch_size_pixels.

    Returns:
        (transformed_patch, index) when self.transforms is set, else
        (patch, None, index) — matching the original contract.
    """
    parts = self.image_files[index].split(" ")
    path = parts[0]
    x = int(parts[1])
    y = int(parts[2])
    level = int(parts[3])
    read_size = int(parts[4]) if len(parts) >= 5 else self.patch_size_pixels

    # read_region coordinates are relative to level 0, not the target level.
    slide = OpenSlide(path)
    try:
        patch = slide.read_region((x, y), level=level, size=(read_size, read_size))
    finally:
        # read_region returns an independent PIL image, so the slide handle
        # can be closed immediately; leaving it open leaks file descriptors
        # across dataloader workers.
        slide.close()
    res = patch.convert("RGB")  # drop the alpha channel
    if read_size != self.patch_size_pixels:
        res = res.resize((self.patch_size_pixels, self.patch_size_pixels), Image.BICUBIC)

    if self.transforms is not None:
        return self.transforms(res, None), index

    return res, None, index

def hsv(self, tile_rgb, patch_size):
tile = np.array(tile_rgb)
tile = cv2.cvtColor(tile, cv2.COLOR_RGB2HSV)
min_ratio = .6

lower_bound = np.array([90, 8, 103])
upper_bound = np.array([180, 255, 255])

Expand Down
4 changes: 3 additions & 1 deletion dinov2/data/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _parse_dataset_str(dataset_str: str):

for token in tokens[1:]:
key, value = token.split("=")
assert key in ("root", "extra", "split", "sample_list_path")
assert key in ("root", "extra", "split", "sample_list_path", "patch_size_pixels")
kwargs[key] = value

if name == "ImageNet":
Expand All @@ -61,6 +61,8 @@ def _parse_dataset_str(dataset_str: str):
class_ = ImageNet22k
elif name.lower() == "pathology":
class_ = SlideDataset
if "patch_size_pixels" in kwargs:
kwargs["patch_size_pixels"] = int(kwargs["patch_size_pixels"])
print("kwargs", kwargs)
else:
raise ValueError(f'Unsupported dataset "{name}"')
Expand Down
4 changes: 2 additions & 2 deletions dinov2/train/ssl_meta_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def backprop_loss(self, loss):
else:
loss.backward()

def forward_backward(self, images, teacher_temp):
def forward_backward(self, images, teacher_temp, loss_scale=1.0):
n_global_crops = 2
assert n_global_crops == 2
n_local_crops = self.cfg.crops.local_crops_number
Expand Down Expand Up @@ -355,7 +355,7 @@ def get_teacher_output():
# accumulate loss
loss_accumulator += self.ibot_loss_weight * ibot_patch_loss

self.backprop_loss(loss_accumulator)
self.backprop_loss(loss_accumulator * loss_scale)

self.fsdp_synchronize_streams()

Expand Down
82 changes: 77 additions & 5 deletions dinov2/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pathlib import Path
import gc
import contextlib
from contextlib import ExitStack
import glob

from fvcore.common.checkpoint import PeriodicCheckpointer
Expand Down Expand Up @@ -290,6 +291,53 @@ def _mlp_kind(block):
student_backbone.norm.bias.copy_(model_pretrained.norm.bias)


def _load_from_teacher_checkpoint(cfg, model):
    """Initialize BOTH the student and the teacher of `model` from a saved
    DINOv2 teacher checkpoint.

    Loads the "teacher" state dict from cfg.train.teacher_checkpoint_path,
    splits it into backbone / dino_head / ibot_head sub-dicts, bicubically
    interpolates the patch position embeddings when the training resolution
    changed, and copies the weights into both networks.
    """
    ckpt_path = str(cfg.train.teacher_checkpoint_path)
    logger.info("Loading from teacher checkpoint: %s", ckpt_path)
    # NOTE(review): on torch >= 2.6 torch.load defaults to weights_only=True,
    # which can reject full checkpoints — confirm against the torch version in use.
    state = torch.load(ckpt_path, map_location="cpu")["teacher"]
    # Strip any DistributedDataParallel "module." prefix from the keys.
    state = {k.replace("module.", ""): v for k, v in state.items()}

    backbone_state = {k.replace("backbone.", ""): v for k, v in state.items() if k.startswith("backbone.")}
    dino_head_state = {k.replace("dino_head.", ""): v for k, v in state.items() if k.startswith("dino_head.")}
    ibot_head_state = {k.replace("ibot_head.", ""): v for k, v in state.items() if k.startswith("ibot_head.")}

    student_backbone = model.student.backbone
    teacher_backbone = model.teacher.backbone

    # Interpolate position embeddings if resolution changed
    pos_embed = backbone_state["pos_embed"]
    # assumes pos_embed carries exactly one extra (CLS) token and that
    # register tokens are stored separately — TODO confirm for this ViT impl
    n_extra_tokens = 1
    cls_pos = pos_embed[:, :n_extra_tokens]
    patch_pos = pos_embed[:, n_extra_tokens:]
    # assumes the checkpoint's patch grid is square — TODO confirm
    orig_size = int(patch_pos.shape[1] ** 0.5)
    target_h, target_w = student_backbone.patch_embed.patches_resolution

    if orig_size != target_h or orig_size != target_w:
        logger.info("Interpolating pos_embed from %dx%d to %dx%d", orig_size, orig_size, target_h, target_w)
        # (1, N, C) -> (1, C, H, W) for spatial interpolation, then back.
        patch_pos = patch_pos.reshape(1, orig_size, orig_size, -1).permute(0, 3, 1, 2)
        patch_pos = F.interpolate(patch_pos, size=(target_h, target_w), mode="bicubic", align_corners=False)
        patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, target_h * target_w, -1)
        backbone_state["pos_embed"] = torch.cat((cls_pos, patch_pos), dim=1)

    with torch.no_grad():
        # strict=False tolerates missing/unexpected keys; the returned
        # IncompatibleKeys messages are logged so mismatches stay visible.
        msg = student_backbone.load_state_dict(backbone_state, strict=False)
        logger.info("Student backbone load: %s", msg)
        msg = teacher_backbone.load_state_dict(backbone_state, strict=False)
        logger.info("Teacher backbone load: %s", msg)

    if dino_head_state:
        msg = model.student.dino_head.load_state_dict(dino_head_state, strict=False)
        logger.info("Student dino_head load: %s", msg)
        msg = model.teacher.dino_head.load_state_dict(dino_head_state, strict=False)
        logger.info("Teacher dino_head load: %s", msg)

    if ibot_head_state:
        msg = model.student.ibot_head.load_state_dict(ibot_head_state, strict=False)
        logger.info("Student ibot_head load: %s", msg)
        msg = model.teacher.ibot_head.load_state_dict(ibot_head_state, strict=False)
        logger.info("Teacher ibot_head load: %s", msg)


def _freeze_student_backbone_except_last_n(cfg, model):
n_unfrozen = cfg.train.unfreeze_last_n_blocks
student_backbone = model.student.backbone
Expand Down Expand Up @@ -391,8 +439,9 @@ def __init__(self, size=224, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.22
]
super().__init__(ops)

eval_size = cfg.crops.global_crops_size
transform = _ResizeAndCrop(
size=224,
size=eval_size,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
)
Expand Down Expand Up @@ -1119,7 +1168,8 @@ def _worker_init(_):
sample_list_path = str(cfg.train.sample_list_path)
if not sample_list_path:
raise ValueError("cfg.train.sample_list_path must be set when streaming_from_hf is False")
dataset_str = f"pathology:root=/data/TCGA/:sample_list_path={sample_list_path}"
patch_size_pixels = getattr(cfg.train, 'patch_size_pixels', 224)
dataset_str = f"pathology:root=/data/TCGA/:sample_list_path={sample_list_path}:patch_size_pixels={patch_size_pixels}"
dataset = make_dataset(
dataset_str=dataset_str,
transform=data_transform,
Expand All @@ -1141,12 +1191,17 @@ def _worker_init(_):
# training loop

iteration = start_iter
accum_steps = getattr(cfg.train, 'gradient_accumulation_steps', 1)
loss_scale = 1.0 / accum_steps

logger.info("Starting training from iteration {}".format(start_iter))
logger.info("Gradient accumulation steps: %d (loss_scale=%.4f)", accum_steps, loss_scale)
metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json")
metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file)
header = "Training"

micro_step = 0

for data in metric_logger.log_every(
data_loader,
10,
Expand Down Expand Up @@ -1192,9 +1247,23 @@ def _worker_init(_):

# compute losses

optimizer.zero_grad(set_to_none=True)
if micro_step % accum_steps == 0:
optimizer.zero_grad(set_to_none=True)

is_accumulating = (micro_step % accum_steps) != (accum_steps - 1)
if is_accumulating and accum_steps > 1:
with ExitStack() as stack:
for v in model.student.values():
stack.enter_context(v.no_sync())
loss_dict = model.forward_backward(data, teacher_temp=teacher_temp, loss_scale=loss_scale)
else:
loss_dict = model.forward_backward(data, teacher_temp=teacher_temp, loss_scale=loss_scale)

micro_step += 1

loss_dict = model.forward_backward(data, teacher_temp=teacher_temp)
# only step optimizer and update teacher after accumulation
if micro_step % accum_steps != 0:
continue

# clip gradients

Expand Down Expand Up @@ -1263,7 +1332,10 @@ def main(args):
print(cfg)
model = SSLMetaArch(cfg).to(torch.device("cuda"))
#Load model here from pretrained.
if cfg.train.use_pretrained:
teacher_ckpt = getattr(cfg.train, 'teacher_checkpoint_path', '')
if teacher_ckpt:
_load_from_teacher_checkpoint(cfg, model)
elif cfg.train.use_pretrained:
_load_pretrained_backbone(cfg, model)
_freeze_student_backbone_except_last_n(cfg, model)

Expand Down
6 changes: 4 additions & 2 deletions dinov2/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ def apply_scaling_rules_to_cfg(cfg): # to fix
if cfg.optim.scaling_rule == "sqrt_wrt_1024":
base_lr = cfg.optim.base_lr
cfg.optim.lr = base_lr
cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
accum = int(getattr(cfg.train, "gradient_accumulation_steps", 1))
effective_batch = cfg.train.batch_size_per_gpu * distributed.get_global_size() * accum
cfg.optim.lr *= math.sqrt(effective_batch / 1024.0)
logger.info(f"sqrt scaling learning rate; base: {base_lr}, effective_batch: {effective_batch}, new: {cfg.optim.lr}")
else:
raise NotImplementedError
return cfg
Expand Down
Loading