1
1
""" Module to load and batch brats dataset """
2
- from typing import Any , Callable , List , Optional , Tuple
2
+ from typing import Any , Callable , List , Literal , Optional , Tuple
3
3
import math
4
4
import nibabel as nib
5
5
import numpy as np
@@ -73,6 +73,7 @@ def __init__(
73
73
clip_mask : bool = True ,
74
74
transform : Optional [Callable [[Any ], torch .Tensor ]] = None ,
75
75
target_transform : Optional [Callable [[Any ], torch .Tensor ]] = None ,
76
+ dimensionality : Literal ["2d" , "3d" ] = "2d" ,
76
77
):
77
78
78
79
self .image_paths = image_paths
@@ -98,16 +99,22 @@ def __init__(
98
99
self .transform = transform
99
100
self .target_transform = target_transform
100
101
102
+ self .dimensionality = dimensionality
103
+
101
104
def __getitem__ (self , index : int ) -> Tuple [torch .Tensor , torch .Tensor ]:
102
- image_index = math .floor (index / BraTSDataset .IMAGE_DIMENSIONS [0 ])
103
- slice_index = index - image_index * BraTSDataset .IMAGE_DIMENSIONS [0 ]
104
- if image_index != self ._current_image_index :
105
- self ._current_image_index = image_index
106
- self ._current_image = self .images [self ._current_image_index ]
107
- self ._current_mask = self .masks [self ._current_image_index ]
108
-
109
- x = torch .from_numpy (self ._current_image [slice_index , :, :])
110
- y = torch .from_numpy (self ._current_mask [slice_index , :, :])
105
+ if self .dimensionality == "2d" :
106
+ image_index = math .floor (index / BraTSDataset .IMAGE_DIMENSIONS [0 ])
107
+ slice_index = index - image_index * BraTSDataset .IMAGE_DIMENSIONS [0 ]
108
+ if image_index != self ._current_image_index :
109
+ self ._current_image_index = image_index
110
+ self ._current_image = self .images [self ._current_image_index ]
111
+ self ._current_mask = self .masks [self ._current_image_index ]
112
+
113
+ x = torch .from_numpy (self ._current_image [slice_index , :, :])
114
+ y = torch .from_numpy (self ._current_mask [slice_index , :, :])
115
+ else :
116
+ x = torch .from_numpy (self .images [index ])
117
+ y = torch .from_numpy (self .masks [index ])
111
118
112
119
if self .transform :
113
120
x = self .transform (x )
0 commit comments