diff --git a/dinov2/hub/text/dinov2_wrapper.py b/dinov2/hub/text/dinov2_wrapper.py
index 31db29696..42c119470 100644
--- a/dinov2/hub/text/dinov2_wrapper.py
+++ b/dinov2/hub/text/dinov2_wrapper.py
@@ -1,4 +1,4 @@
-from typing import Sequence
+from typing import Sequence, Tuple, Union
 
 import torch
 
@@ -26,12 +26,12 @@ def forward(self, img, is_training: bool):
     def get_intermediate_layers(
         self,
         x: torch.Tensor,
-        n: int | Sequence[int] = 1,  # Layers or n last layers to take
+        n: Union[int, Sequence[int]] = 1,  # Layers or n last layers to take
         reshape: bool = False,
         return_class_token: bool = False,
         return_register_tokens: bool = False,
         norm=True,
-    ) -> tuple[torch.Tensor] | tuple[tuple[torch.Tensor, ...], ...]:
+    ) -> Union[Tuple[torch.Tensor], Tuple[Tuple[torch.Tensor, ...], ...]]:
         if self.model.chunked_blocks:
             outputs = self.model._get_intermediate_layers_chunked(x, n)
         else:
diff --git a/dinov2/layers/attention.py b/dinov2/layers/attention.py
index f1d3dabf1..ff9458914 100644
--- a/dinov2/layers/attention.py
+++ b/dinov2/layers/attention.py
@@ -11,6 +11,8 @@
 import os
 import warnings
 
+from typing import Optional
+
 import torch
 from torch import nn, Tensor
 
@@ -55,7 +57,7 @@ def __init__(
         self.proj_drop = nn.Dropout(proj_drop)
 
     def init_weights(
-        self, init_attn_std: float | None = None, init_proj_std: float | None = None, factor: float = 1.0
+        self, init_attn_std: Optional[float] = None, init_proj_std: Optional[float] = None, factor: float = 1.0
     ) -> None:
         init_attn_std = init_attn_std or (self.dim**-0.5)
         init_proj_std = init_proj_std or init_attn_std * factor
diff --git a/dinov2/layers/block.py b/dinov2/layers/block.py
index 7e83b71cc..d79f69093 100644
--- a/dinov2/layers/block.py
+++ b/dinov2/layers/block.py
@@ -147,9 +147,9 @@ def __init__(
 
     def init_weights(
         self,
-        init_attn_std: float | None = None,
-        init_proj_std: float | None = None,
-        init_fc_std: float | None = None,
+        init_attn_std: Optional[float] = None,
+        init_proj_std: Optional[float] = None,
+        init_fc_std: Optional[float] = None,
         factor: float = 1.0,
     ) -> None:
         init_attn_std = init_attn_std or (self.dim**-0.5)
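
Taken together, the hunks above replace PEP 604 union syntax (`float | None`, `int | Sequence[int]`) and builtin generics (`tuple[...]`) with their `typing` equivalents (`Optional`, `Union`, `Tuple`). A minimal standalone sketch of the likely motivation, assuming the goal is compatibility with runtime-evaluated annotations on Python < 3.10; the snippet and its default values are illustrative and not part of the patch:

    import sys
    from typing import Optional

    # The typing.Optional spelling is accepted on every supported Python version.
    def init_weights(init_attn_std: Optional[float] = None, factor: float = 1.0) -> None:
        init_attn_std = init_attn_std or 0.02  # placeholder default, illustration only
        print(init_attn_std * factor)

    if sys.version_info < (3, 10):
        try:
            # Evaluating a PEP 604 union at runtime raises before Python 3.10
            # (and subscripting builtin generics like tuple[...] raises before 3.9).
            eval("float | None")
        except TypeError as exc:
            print(f"PEP 604 union rejected: {exc}")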