diff --git a/engine/data/transforms/_transforms.py b/engine/data/transforms/_transforms.py
index 31588df5..b8318031 100644
--- a/engine/data/transforms/_transforms.py
+++ b/engine/data/transforms/_transforms.py
@@ -70,6 +70,10 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = self._fill[type(inpt)]
         padding = params['padding']
         return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]
+
+    # Compatibility with torchvision>0.20.0
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return self._transform(inpt, params)
 
     def __call__(self, *inputs: Any) -> Any:
         outputs = super().forward(*inputs)
@@ -91,6 +95,7 @@ def __call__(self, *inputs: Any) -> Any:
         return super().forward(*inputs)
 
 
+
 @register()
 class ConvertBoxes(T.Transform):
     _transformed_types = (
@@ -112,6 +117,19 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             inpt = inpt / torch.tensor(spatial_size[::-1]).tile(2)[None]
 
         return inpt
+
+    # Compatibility with torchvision>0.20.0
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        spatial_size = getattr(inpt, _boxes_keys[1])
+        if self.fmt:
+            in_fmt = inpt.format.value.lower()
+            inpt = torchvision.ops.box_convert(inpt, in_fmt=in_fmt, out_fmt=self.fmt.lower())
+            inpt = convert_to_tv_tensor(inpt, key='boxes', box_format=self.fmt.upper(), spatial_size=spatial_size)
+
+        if self.normalize:
+            inpt = inpt / torch.tensor(spatial_size[::-1]).tile(2)[None]
+
+        return inpt
 
 
 @register()
@@ -135,3 +153,16 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             inpt = Image(inpt)
 
         return inpt
+
+    # Compatibility with torchvision>0.20.0
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt = F.pil_to_tensor(inpt)
+        if self.dtype == 'float32':
+            inpt = inpt.float()
+
+        if self.scale:
+            inpt = inpt / 255.
+
+        inpt = Image(inpt)
+
+        return inpt