diff --git a/dinov2/configs/train/vitg14_reg4_highres.yaml b/dinov2/configs/train/vitg14_reg4_highres.yaml new file mode 100644 index 0000000..80b7b68 --- /dev/null +++ b/dinov2/configs/train/vitg14_reg4_highres.yaml @@ -0,0 +1,60 @@ +dino: + head_n_prototypes: 131072 + head_bottleneck_dim: 384 + do_kde: True + kde_loss_weight: .05 + koleo_loss_weight: 0 + do_koleo: False +ibot: + loss_weight: 1.0 + mask_sample_probability: 0.5 + mask_ratio_min_max: + - 0.1 + - 0.45 + separate_head: true + head_n_prototypes: 131072 +train: + sample_list_path: sample_dataset_highres_1M.txt # magnification-aware: [1, 0.5, 0.25, 0.125] µm/px at 448px + streaming_from_hf: false + streaming_dataset_path: medarc/TCGA-12K-parquet + batch_size_per_gpu: 6 + centering: sinkhorn_knopp + use_pretrained: False + teacher_checkpoint_path: /data/OpenMidnight_ckpts/OM_replication_interpolationfix/eval/training_45000/teacher_checkpoint.pth + OFFICIAL_EPOCH_LENGTH: 1250 + num_workers: 24 + prefetch_factor: 8 + skip_checkpointer: false + gradient_accumulation_steps: 4 + patch_size_pixels: 448 +student: + arch: vit_giant2 + patch_size: 14 + drop_path_rate: 0.4 + ffn_layer: swiglufused + block_chunks: 4 + num_register_tokens: 4 +teacher: + momentum_teacher: 0.994 +optim: + epochs: 96 # 120k iterations / 1250 steps_per_epoch = 96 epochs + early_stop: 96 + weight_decay_end: 0.2 + base_lr: 1.0e-04 + warmup_epochs: 2 + layerwise_decay: 1.0 +crops: + global_crops_scale: + - 0.32 + - 1.0 + local_crops_number: 8 + local_crops_scale: + - 0.05 + - 0.32 + global_crops_size: 392 + local_crops_size: 168 +evaluation: + eval_period_iterations: 5000 + bach_root: /block/eva-data/bach + breakhis_root: /block/eva-data/breakhis + pcam_root: /block/eva-data/patch_camelyon diff --git a/dinov2/data/datasets/slide_dataset.py b/dinov2/data/datasets/slide_dataset.py index d31a1ed..0e20224 100644 --- a/dinov2/data/datasets/slide_dataset.py +++ b/dinov2/data/datasets/slide_dataset.py @@ -8,18 +8,20 @@ from openslide import 
def __getitem__(self, index: int) -> Tuple[Any, Any]:
    """Read the patch described by sample-list entry ``index`` as an RGB image.

    Each entry of ``self.image_files`` is a space-separated record
    ``path x y level [read_size]``. ``(x, y)`` is passed to
    ``OpenSlide.read_region`` as the region location (OpenSlide interprets
    the location in the level-0 reference frame), ``level`` is the pyramid
    level to read from, and the optional ``read_size`` is the side length,
    in pixels at that level, of the square region to read. When
    ``read_size`` is absent, ``self.patch_size_pixels`` is used.

    A region whose ``read_size`` differs from ``self.patch_size_pixels`` is
    resampled (bicubic) so every returned patch has the same side length.

    Returns:
        ``(self.transforms(image, None), index)`` when ``self.transforms``
        is set, otherwise ``(image, None, index)``.
    """
    parts = self.image_files[index].split(" ")
    path = parts[0]
    x = int(parts[1])
    y = int(parts[2])
    level = int(parts[3])
    read_size = int(parts[4]) if len(parts) >= 5 else self.patch_size_pixels

    # A slide is opened for every sample. Closing it deterministically keeps
    # long-running DataLoader workers from accumulating native handles and
    # file descriptors until the garbage collector happens to run.
    with OpenSlide(path) as slide:
        patch = slide.read_region((x, y), level=level, size=(read_size, read_size))
        # read_region yields RGBA; drop the alpha channel.
        res = patch.convert("RGB")

    if read_size != self.patch_size_pixels:
        target = (self.patch_size_pixels, self.patch_size_pixels)
        res = res.resize(target, Image.BICUBIC)

    if self.transforms is not None:
        return self.transforms(res, None), index
    return res, None, index
def _load_from_teacher_checkpoint(cfg, model):
    """Initialise the student and teacher from a saved teacher checkpoint.

    The checkpoint at ``cfg.train.teacher_checkpoint_path`` is expected to
    hold a ``"teacher"`` state dict whose keys start with ``backbone.``,
    ``dino_head.`` or ``ibot_head.`` (optionally containing ``module.``
    wrapper segments). The backbone weights are loaded into both
    ``model.student.backbone`` and ``model.teacher.backbone``; each head is
    loaded into both branches only when the checkpoint contains weights for
    it. All loads use ``strict=False`` and the resulting key report is
    logged.

    If the checkpoint's patch position-embedding grid differs from the
    student's ``patch_embed.patches_resolution``, the patch part of
    ``pos_embed`` is bicubically resampled to the new grid; the leading
    class-token embedding is kept unchanged.

    Raises:
        KeyError: if the checkpoint has no ``"teacher"`` entry or the
            backbone state has no ``pos_embed``.
        ValueError: if the number of patch position embeddings is not a
            perfect square, so no square source grid can be recovered.
    """
    import math  # function-scope import keeps this block self-contained

    ckpt_path = str(cfg.train.teacher_checkpoint_path)
    logger.info("Loading from teacher checkpoint: %s", ckpt_path)
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    state = torch.load(ckpt_path, map_location="cpu")["teacher"]
    # Drop "module." wrapper segments wherever they occur (unchanged behaviour).
    state = {k.replace("module.", ""): v for k, v in state.items()}

    def _sub_state(prefix):
        # Strip only the *leading* component prefix; str.replace would also
        # rewrite any later occurrence of the same text inside a key.
        return {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}

    backbone_state = _sub_state("backbone.")
    head_states = {
        "dino_head": _sub_state("dino_head."),
        "ibot_head": _sub_state("ibot_head."),
    }

    student_backbone = model.student.backbone
    teacher_backbone = model.teacher.backbone

    # Interpolate position embeddings if the resolution changed.
    pos_embed = backbone_state["pos_embed"]
    n_extra_tokens = 1  # leading class-token slot of pos_embed
    cls_pos = pos_embed[:, :n_extra_tokens]
    patch_pos = pos_embed[:, n_extra_tokens:]
    n_patches = patch_pos.shape[1]
    # Exact integer square root: a float sqrt can round, and a non-square
    # count would otherwise only surface as an opaque reshape failure.
    orig_size = math.isqrt(n_patches)
    if orig_size * orig_size != n_patches:
        raise ValueError(
            f"pos_embed in {ckpt_path} has {n_patches} patch positions, "
            "which is not a square grid"
        )
    target_h, target_w = student_backbone.patch_embed.patches_resolution

    if orig_size != target_h or orig_size != target_w:
        logger.info("Interpolating pos_embed from %dx%d to %dx%d", orig_size, orig_size, target_h, target_w)
        grid = patch_pos.reshape(1, orig_size, orig_size, -1).permute(0, 3, 1, 2)
        # Resample in float32 and cast back: bicubic interpolation of
        # half-precision tensors on CPU is not reliably available/accurate.
        grid = F.interpolate(grid.float(), size=(target_h, target_w), mode="bicubic", align_corners=False)
        grid = grid.to(pos_embed.dtype)
        patch_pos = grid.permute(0, 2, 3, 1).reshape(1, target_h * target_w, -1)
        backbone_state["pos_embed"] = torch.cat((cls_pos, patch_pos), dim=1)

    with torch.no_grad():
        for role, backbone in (("Student", student_backbone), ("Teacher", teacher_backbone)):
            msg = backbone.load_state_dict(backbone_state, strict=False)
            logger.info("%s backbone load: %s", role, msg)

        for head_name, head_state in head_states.items():
            # Heads are only touched when the checkpoint provides them, so a
            # model without that head is never accessed.
            if not head_state:
                continue
            for role, branch in (("Student", model.student), ("Teacher", model.teacher)):
                msg = getattr(branch, head_name).load_state_dict(head_state, strict=False)
                logger.info("%s %s load: %s", role, head_name, msg)
metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file) header = "Training" + micro_step = 0 + for data in metric_logger.log_every( data_loader, 10, @@ -1192,9 +1247,23 @@ def _worker_init(_): # compute losses - optimizer.zero_grad(set_to_none=True) + if micro_step % accum_steps == 0: + optimizer.zero_grad(set_to_none=True) + + is_accumulating = (micro_step % accum_steps) != (accum_steps - 1) + if is_accumulating and accum_steps > 1: + with ExitStack() as stack: + for v in model.student.values(): + stack.enter_context(v.no_sync()) + loss_dict = model.forward_backward(data, teacher_temp=teacher_temp, loss_scale=loss_scale) + else: + loss_dict = model.forward_backward(data, teacher_temp=teacher_temp, loss_scale=loss_scale) + + micro_step += 1 - loss_dict = model.forward_backward(data, teacher_temp=teacher_temp) + # only step optimizer and update teacher after accumulation + if micro_step % accum_steps != 0: + continue # clip gradients @@ -1263,7 +1332,10 @@ def main(args): print(cfg) model = SSLMetaArch(cfg).to(torch.device("cuda")) #Load model here from pretrained. 
"""Generate a magnification-aware sample list of tissue patches from SVS slides.

For every slide under ``--data_root`` and every target resolution in
``TARGET_MPPS`` the script samples random regions, keeps those passing an HSV
tissue check, and writes one line per kept region:

    <slide path> <x> <y> <level> <read_size>

``read_size`` is the side length (in pixels at ``level``) that must be read so
that, once resized to ``TILE_SIZE``, the patch is at the target microns/pixel.
The finished list is shuffled in place.
"""
import argparse
import random
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path

import cv2
import numpy as np
from openslide import OpenSlide
from tqdm import tqdm

TARGET_MPPS = [1.0, 0.5, 0.25, 0.125]  # target resolutions, microns per pixel
TILE_SIZE = 448  # side length of the final patch, in pixels
MAX_UPSCALE = 2.0  # skip targets needing more than this much upsampling
MPP_KEY = "openslide.mpp-x"


def parse_args():
    """Parse the command-line options of the script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", type=str, default="/block/TCGA")
    # NOTE(review): the default name says "2M" while --target_patches defaults
    # to 1M; left unchanged so existing invocations keep their output path.
    parser.add_argument("--output", type=str, default="sample_dataset_highres_2M.txt")
    parser.add_argument("--target_patches", type=int, default=1_000_000)
    parser.add_argument("--max_tries", type=int, default=1000)
    parser.add_argument("--patches_per_mag", type=int, default=5)
    parser.add_argument("--workers", type=int, default=16)
    return parser.parse_args()


def hsv_tissue_check(tile_rgb):
    """Return True when more than 60% of the tile's pixels fall in the HSV tissue range."""
    tile = np.array(tile_rgb)
    tile = cv2.cvtColor(tile, cv2.COLOR_RGB2HSV)
    mask = cv2.inRange(tile, np.array([90, 8, 103]), np.array([180, 255, 255]))
    return np.count_nonzero(mask) / mask.size > 0.6


def process_slide(task):
    """Sample tissue patches from one slide at every usable target resolution.

    Args:
        task: ``(path, seed, patches_per_mag, max_tries)``. ``seed`` makes the
            sampled coordinates reproducible; at most ``patches_per_mag``
            patches are kept per target resolution, after at most
            ``max_tries`` random draws.

    Returns:
        A list of newline-terminated ``"path x y level read_size"`` lines.
        The list is empty when the slide cannot be opened or lacks usable
        microns-per-pixel metadata.
    """
    path, seed, patches_per_mag, max_tries = task
    # A private generator yields the same draws as seeding the global one,
    # without mutating process-wide random state.
    rng = random.Random(seed)
    results = []

    try:
        slide = OpenSlide(path)
    except Exception:
        # Best effort: unreadable slides are skipped rather than aborting the run.
        return results

    try:
        if MPP_KEY not in slide.properties:
            return results
        try:
            native_mpp = float(slide.properties[MPP_KEY])
        except ValueError:
            return results
        if native_mpp <= 0:
            # Malformed metadata; would otherwise divide by zero below.
            return results

        lv0_w, lv0_h = slide.level_dimensions[0]

        for target_mpp in TARGET_MPPS:
            if native_mpp > MAX_UPSCALE * target_mpp:
                continue

            # Pick the last pyramid level whose downsample does not exceed the
            # one required, so the remaining resize is never an upscale of an
            # over-downsampled level.
            target_ds = target_mpp / native_mpp
            best_level = 0
            for level, downsample in enumerate(slide.level_downsamples):
                if downsample <= target_ds:
                    best_level = level

            level_downsample = slide.level_downsamples[best_level]
            level_mpp = native_mpp * level_downsample
            read_size = int(round(TILE_SIZE * target_mpp / level_mpp))

            # Footprint of the region in level-0 pixels, used to keep the
            # sampled top-left corner inside the slide.
            physical_lv0 = int(read_size * level_downsample)
            max_x = lv0_w - physical_lv0
            max_y = lv0_h - physical_lv0
            if max_x <= 0 or max_y <= 0:
                continue

            collected = 0
            tries = 0
            while collected < patches_per_mag and tries < max_tries:
                tries += 1
                x = rng.randint(0, max_x)
                y = rng.randint(0, max_y)
                patch = slide.read_region((x, y), level=best_level, size=(read_size, read_size)).convert("RGB")
                if hsv_tissue_check(patch):
                    results.append(f"{path} {x} {y} {best_level} {read_size}\n")
                    collected += 1
    finally:
        # Always release the native handle, including when a read fails.
        slide.close()

    return results


def main():
    """Collect ``--target_patches`` sample lines, write them, then shuffle the file."""
    args = parse_args()

    data_root = Path(args.data_root)
    svs_files = sorted(str(p) for p in data_root.rglob("*.svs"))
    if not svs_files:
        raise RuntimeError(f"No SVS files found under {data_root}")

    print(f"Found {len(svs_files)} SVS files")
    print(f"Target magnifications (um/px): {TARGET_MPPS}")
    print(f"Tile size: {TILE_SIZE}px, max upscale: {MAX_UPSCALE}x")
    print(f"Target: {args.target_patches} patches, workers: {args.workers}")

    total = 0
    pass_idx = 0
    pbar = tqdm(total=args.target_patches, desc="Patches")
    # One pool for the whole run instead of re-spawning workers every pass.
    with open(args.output, "w") as f, ProcessPoolExecutor(max_workers=args.workers) as executor:
        while total < args.target_patches:
            total_before = total
            tasks = [(path, pass_idx * 100000 + i, args.patches_per_mag, args.max_tries)
                     for i, path in enumerate(svs_files)]
            random.shuffle(tasks)

            futures = [executor.submit(process_slide, task) for task in tasks]
            for future in futures:
                for line in future.result():
                    if total >= args.target_patches:
                        break
                    f.write(line)
                    total += 1
                    pbar.update(1)
                if total >= args.target_patches:
                    # Drop work that has not started yet so shutdown does not
                    # wait for the remainder of the pass.
                    for pending in futures:
                        pending.cancel()
                    break

            pass_idx += 1
            print(f"Pass {pass_idx} complete, {total} patches so far")

            if total == total_before:
                # No slide produced a patch (e.g. missing MPP metadata or no
                # tissue): further passes would loop forever.
                print("No new patches in the last pass; stopping early")
                break

    pbar.close()
    print(f"Generated {total} patches. Shuffling...")

    with open(args.output, "r") as f:
        lines = f.readlines()

    random.shuffle(lines)

    with open(args.output, "w") as f:
        f.writelines(lines)

    print("Done")


# Required for ProcessPoolExecutor under the spawn/forkserver start methods:
# worker processes import this module and must not re-run the script body.
if __name__ == "__main__":
    main()
"${BASH_SOURCE[0]}")" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +# Clean output directory only when not resuming +if [[ "${RESUME}" == "True" ]]; then + echo "Resume enabled; preserving ${OUTPUT_DIR}" + RESUME_FLAG="" +else + echo "Resume disabled; cleaning ${OUTPUT_DIR}" + rm -rf "${OUTPUT_DIR}" + RESUME_FLAG="--no-resume" +fi +mkdir -p "${OUTPUT_DIR}" + +echo "[High-Res Finetuning] Starting training..." +echo "MASTER_ADDR=${MASTER_ADDR}" +echo "MASTER_PORT=${MASTER_PORT}" +echo "NNODES=${NNODES}, NPROC_PER_NODE=${NPROC_PER_NODE}" +echo "CONFIG_FILE=${CONFIG_FILE}" +echo "OUTPUT_DIR=${OUTPUT_DIR}" + +uv run torchrun \ + --nnodes "${NNODES}" \ + --nproc_per_node "${NPROC_PER_NODE}" \ + --node_rank "${NODE_RANK}" \ + --master_addr "${MASTER_ADDR}" \ + --master_port "${MASTER_PORT}" \ + dinov2/train/train.py \ + --config-file "${CONFIG_FILE}" \ + --output-dir "${OUTPUT_DIR}" \ + ${RESUME_FLAG}