68 changes: 68 additions & 0 deletions dinov2/configs/train/vits14_pyramid.yaml
@@ -0,0 +1,68 @@
dino:
  head_n_prototypes: 65536
  head_bottleneck_dim: 256
  head_hidden_dim: 2048
  head_nlayers: 3
  loss_weight: 1.0
  koleo_loss_weight: 0.1
ibot:
  loss_weight: 1.0
  mask_sample_probability: 0.5
  mask_ratio_min_max:
  - 0.1
  - 0.5
  separate_head: true
  head_n_prototypes: 65536
  head_bottleneck_dim: 256
  head_hidden_dim: 2048
  head_nlayers: 3
train:
  batch_size_per_gpu: 64
  dataset_path: HuggingFace:root=tsbpp/fall2025_deeplearning:split=train
  centering: sinkhorn_knopp
  output_dir: /gpfs/data/shenlab/aj4718/dinov2/logs/vits14_pyramid
  OFFICIAL_EPOCH_LENGTH: 7813
  num_workers: 4
  saveckp_freq: 1
student:
  arch: vit_small
  patch_size: 16
  drop_path_rate: 0.4
  layerscale: 1.0e-05
  drop_path_uniform: true
  pretrained_weights: ''
  ffn_layer: mlp
  block_chunks: 0
teacher:
  momentum_teacher: 0.994
  final_momentum_teacher: 1.0
  warmup_teacher_temp: 0.04
  teacher_temp: 0.07
  warmup_teacher_temp_epochs: 30
optim:
  epochs: 300
  weight_decay: 0.04
  weight_decay_end: 0.4
  base_lr: 0.002
  lr: 0.0  # will be set by apply_scaling_rules_to_cfg
  warmup_epochs: 10
  min_lr: 1.0e-06
  clip_grad: 3.0
  freeze_last_layer_epochs: 1
  scaling_rule: sqrt_wrt_1024
  patch_embed_lr_mult: 0.2
  layerwise_decay: 0.9
  adamw_beta1: 0.9
  adamw_beta2: 0.999
crops:
  global_crops_scale:
  - 0.32
  - 1.0
  local_crops_number: 8
  local_crops_scale:
  - 0.05
  - 0.32
  global_crops_size: 96
  local_crops_size: 48
evaluation:
  eval_period_iterations: 10000
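
Note that lr above is a placeholder that apply_scaling_rules_to_cfg overwrites at startup. Below is a minimal sketch of the sqrt_wrt_1024 convention, assuming it scales base_lr by the square root of the global batch size relative to 1024; num_gpus is a hypothetical world size, not a value from this config.

import math

# Sketch (assumption): sqrt_wrt_1024 sets lr = base_lr * sqrt(global_batch_size / 1024).
base_lr = 0.002          # optim.base_lr above
batch_size_per_gpu = 64  # train.batch_size_per_gpu above
num_gpus = 4             # hypothetical world size, for illustration

global_batch_size = batch_size_per_gpu * num_gpus
lr = base_lr * math.sqrt(global_batch_size / 1024.0)
print(f"effective lr for {num_gpus} GPUs: {lr:.6f}")  # 0.001000 for a global batch of 256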
1 change: 1 addition & 0 deletions dinov2/data/datasets/__init__.py
@@ -5,3 +5,4 @@

from .image_net import ImageNet
from .image_net_22k import ImageNet22k
from .huggingface import HuggingFaceDataset
52 changes: 52 additions & 0 deletions dinov2/data/datasets/huggingface.py
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from typing import Any, Callable, Optional, Tuple
import logging

from datasets import load_dataset
from PIL import Image
import torch
from torchvision.datasets import VisionDataset

logger = logging.getLogger("dinov2")


class HuggingFaceDataset(VisionDataset):
    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.split = split
        self.dataset_name = root  # Reusing root as dataset name for consistency with other datasets

        logger.info(f"Loading HuggingFace dataset: {self.dataset_name}, split: {self.split}")
        self.dataset = load_dataset(self.dataset_name, split=self.split)
        logger.info(f"Loaded {len(self.dataset)} samples from {self.dataset_name}")

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        item = self.dataset[index]

        # Assumes 'image' column. Modify if needed.
        image = item['image']

        if image.mode != 'RGB':
            image = image.convert('RGB')

        target = 0  # Dummy target for SSL

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def __len__(self) -> int:
        return len(self.dataset)
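
A standalone sanity check of the new class might look like the sketch below. Assumptions: the tsbpp/fall2025_deeplearning dataset id from the config above is reachable and exposes an 'image' column, and the Resize/CenterCrop transform is purely illustrative rather than the project's training augmentation.

from torch.utils.data import DataLoader
from torchvision import transforms

from dinov2.data.datasets import HuggingFaceDataset

# Illustrative transform only; training uses the project's own augmentation pipeline.
transform = transforms.Compose([
    transforms.Resize(96),      # matches global_crops_size in the config above
    transforms.CenterCrop(96),
    transforms.ToTensor(),
])

dataset = HuggingFaceDataset(
    root="tsbpp/fall2025_deeplearning",  # dataset id taken from train.dataset_path
    split="train",
    transform=transform,
)
loader = DataLoader(dataset, batch_size=8, num_workers=2)

images, targets = next(iter(loader))
print(images.shape, targets)  # expected: torch.Size([8, 3, 96, 96]) and a tensor of zeros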
4 changes: 3 additions & 1 deletion dinov2/data/loaders.py
@@ -10,7 +10,7 @@
import torch
from torch.utils.data import Sampler

from .datasets import ImageNet, ImageNet22k
from .datasets import ImageNet, ImageNet22k, HuggingFaceDataset
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler


@@ -58,6 +58,8 @@ def _parse_dataset_str(dataset_str: str):
            kwargs["split"] = ImageNet.Split[kwargs["split"]]
    elif name == "ImageNet22k":
        class_ = ImageNet22k
    elif name == "HuggingFace":
        class_ = HuggingFaceDataset
    else:
        raise ValueError(f'Unsupported dataset "{name}"')
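
With this branch in place, the train.dataset_path string from the config above resolves to HuggingFaceDataset. The sketch below is a simplified mirror of the parsing convention, shown for illustration only; the example string comes from the config, and the note about make_dataset reflects how the surrounding module is expected to use the parsed result.

# Simplified mirror of _parse_dataset_str, for illustration only.
dataset_str = "HuggingFace:root=tsbpp/fall2025_deeplearning:split=train"

name, *options = dataset_str.split(":")
kwargs = dict(option.split("=", 1) for option in options)

assert name == "HuggingFace"
print(kwargs)  # {'root': 'tsbpp/fall2025_deeplearning', 'split': 'train'}
# make_dataset then presumably builds HuggingFaceDataset(**kwargs) with the training
# transform and target_transform attached.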
