|
12 | 12 | import warnings
|
13 | 13 | from enum import Enum
|
14 | 14 | from functools import partial
|
15 |
| -from typing import TYPE_CHECKING |
| 15 | +from typing import TYPE_CHECKING, Callable |
16 | 16 |
|
17 | 17 | import cv2
|
18 | 18 | import numpy as np
|
@@ -86,15 +86,17 @@ class ResizeLib(Enum):
|
86 | 86 |
|
87 | 87 | def get_compatible_dtype(self, dtype: np.dtype | type) -> np.dtype:
|
88 | 88 | """Returns a suitable dtype with which the library can work. Warns if information loss could occur."""
|
| 89 | + lossless: dict[type, type] |
| 90 | + infoloss: dict[type, type] |
89 | 91 | if self is ResizeLib.CV2:
|
90 | 92 | lossless = {bool: np.uint8, np.float16: np.float32}
|
91 | 93 | infoloss = {x: np.int32 for x in (np.uint32, np.int64, np.uint64, int)}
|
92 | 94 | if self is ResizeLib.PIL:
|
93 | 95 | lossless = {np.float16: np.float32}
|
94 | 96 | infoloss = {x: np.int32 for x in (np.uint16, np.uint32, np.int64, np.uint64, int)}
|
95 | 97 |
|
96 |
| - lossless_casts = {np.dtype(k): np.dtype(v) for k, v in lossless.items()} |
97 |
| - infoloss_casts = {np.dtype(k): np.dtype(v) for k, v in infoloss.items()} |
| 98 | + lossless_casts: dict[np.dtype, np.dtype] = {np.dtype(k): np.dtype(v) for k, v in lossless.items()} |
| 99 | + infoloss_casts: dict[np.dtype, np.dtype] = {np.dtype(k): np.dtype(v) for k, v in infoloss.items()} |
98 | 100 | return self._extract_compatible_dtype(dtype, lossless_casts, infoloss_casts)
|
99 | 101 |
|
100 | 102 | @staticmethod
|
@@ -159,14 +161,11 @@ def spatially_resize_image(
|
159 | 161 | old_dtype, new_dtype = data.dtype, resize_library.get_compatible_dtype(data.dtype)
|
160 | 162 | data = data.astype(new_dtype)
|
161 | 163 |
|
| 164 | + resize_function: Callable[[np.ndarray], np.ndarray] |
162 | 165 | if resize_library is ResizeLib.CV2:
|
163 | 166 | resize_function = partial(cv2.resize, dsize=size, interpolation=resize_method.get_cv2_method(data.dtype))
|
164 | 167 | else:
|
165 |
| - resize_function = partial( |
166 |
| - _pil_resize_ndarray, # type: ignore[arg-type] |
167 |
| - size=size, |
168 |
| - method=resize_method.get_pil_method(), |
169 |
| - ) |
| 168 | + resize_function = partial(_pil_resize_ndarray, size=size, method=resize_method.get_pil_method()) |
170 | 169 |
|
171 | 170 | resized_data = _apply_to_spatial_axes(resize_function, data, spatial_axes)
|
172 | 171 |
|
|
0 commit comments