diff --git a/dinov3/hub/backbones.py b/dinov3/hub/backbones.py index 3d83e2f..1f30e47 100644 --- a/dinov3/hub/backbones.py +++ b/dinov3/hub/backbones.py @@ -7,12 +7,21 @@ from enum import Enum from typing import List, Optional, Union from urllib.parse import urlparse +from urllib.request import url2pathname from pathlib import Path import torch from .utils import DINOV3_BASE_URL +LOADERS = { + "file": lambda url, **_: torch.load( + Path(url2pathname(urlparse(url).path)), map_location="cpu", weights_only=True + ), + "http": lambda url, **kwargs: torch.hub.load_state_dict_from_url(url, map_location="cpu", **kwargs), + "https": lambda url, **kwargs: torch.hub.load_state_dict_from_url(url, map_location="cpu", **kwargs), +} + class Weights(Enum): LVD1689M = "LVD1689M" @@ -30,6 +39,18 @@ def convert_path_or_url_to_url(path: str) -> str: return Path(path).expanduser().resolve().as_uri() +def _load_state_dict(url: str, check_hash: bool = False) -> dict[str, torch.Tensor]: + scheme = urlparse(url).scheme + if scheme not in LOADERS: + raise ValueError( + f"Unsupported URL scheme: {scheme}. Supported schemes are: {', '.join(LOADERS.keys())}." + ) + + loader = LOADERS[scheme] + kwargs = {"check_hash": check_hash} if scheme in ("http", "https") else {} + return loader(url, **kwargs) + + def _make_dinov3_vit_model_arch( *, patch_size: int = 16, @@ -137,7 +158,7 @@ def _make_dinov3_vit( ) else: url = convert_path_or_url_to_url(weights) - state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=check_hash) + state_dict = _load_state_dict(url, check_hash=check_hash) model.load_state_dict(state_dict, strict=True) else: model.init_weights()