diff --git a/engine/data/transforms/_transforms.py b/engine/data/transforms/_transforms.py
index 31588df5..86e4b1dd 100644
--- a/engine/data/transforms/_transforms.py
+++ b/engine/data/transforms/_transforms.py
@@ -71,6 +71,12 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         padding = params['padding']
         return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]
 
+    # added override for torchvision >=0.21
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        fill = self._fill[type(inpt)]
+        padding = params['padding']
+        return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]
+
     def __call__(self, *inputs: Any) -> Any:
         outputs = super().forward(*inputs)
         if len(outputs) > 1 and isinstance(outputs[1], dict):
@@ -113,6 +119,19 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
         return inpt
 
+    # added override for torchvision >=0.21
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        spatial_size = getattr(inpt, _boxes_keys[1])
+        if self.fmt:
+            in_fmt = inpt.format.value.lower()
+            inpt = torchvision.ops.box_convert(inpt, in_fmt=in_fmt, out_fmt=self.fmt.lower())
+            inpt = convert_to_tv_tensor(inpt, key='boxes', box_format=self.fmt.upper(), spatial_size=spatial_size)
+
+        if self.normalize:
+            inpt = inpt / torch.tensor(spatial_size[::-1]).tile(2)[None]
+
+        return inpt
+
 
 @register()
 class ConvertPILImage(T.Transform):
@@ -135,3 +154,16 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         inpt = Image(inpt)
 
         return inpt
+
+    # added override for torchvision >=0.21
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt = F.pil_to_tensor(inpt)
+        if self.dtype == 'float32':
+            inpt = inpt.float()
+
+        if self.scale:
+            inpt = inpt / 255.
+
+        inpt = Image(inpt)
+
+        return inpt
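
Note (not part of the patch): each transform above keeps its existing private `_transform` hook and gains a public `transform` hook with an identical body because, per the comment in the patch, torchvision >= 0.21 dispatches custom v2 transforms through `transform(inpt, params)` rather than the older `_transform(inpt, params)`. Below is a minimal sketch of the same compatibility pattern; the class name and constructor arguments are illustrative and do not come from this repository.

    from typing import Any, Dict

    import torchvision.transforms.v2 as T
    import torchvision.transforms.v2.functional as F


    class PadCompat(T.Transform):
        """Pads inputs by a fixed amount; one body serves both torchvision v2 dispatch hooks."""

        def __init__(self, padding: int, fill: int = 0, padding_mode: str = 'constant') -> None:
            super().__init__()
            self.padding = padding
            self.fill = fill
            self.padding_mode = padding_mode

        # hook used by torchvision >= 0.21 (assumption taken from the patch comment)
        def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
            return F.pad(inpt, padding=self.padding, fill=self.fill, padding_mode=self.padding_mode)

        # hook used by older torchvision releases; delegate to the shared logic above
        def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
            return self.transform(inpt, params)

Defining both hooks over one shared implementation is what the patch does for the padding, box-conversion, and PIL-to-tensor transforms, so the same logic runs whichever hook the installed torchvision calls.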