11""" Module to load and batch brats dataset """
2- from typing import Any , Callable , List , Optional , Tuple
2+ from typing import Any , Callable , List , Literal , Optional , Tuple
33import math
44import nibabel as nib
55import numpy as np
@@ -73,6 +73,7 @@ def __init__(
7373 clip_mask : bool = True ,
7474 transform : Optional [Callable [[Any ], torch .Tensor ]] = None ,
7575 target_transform : Optional [Callable [[Any ], torch .Tensor ]] = None ,
76+ dimensionality : Literal ["2d" , "3d" ] = "2d" ,
7677 ):
7778
7879 self .image_paths = image_paths
@@ -98,16 +99,22 @@ def __init__(
9899 self .transform = transform
99100 self .target_transform = target_transform
100101
102+ self .dimensionality = dimensionality
103+
101104 def __getitem__ (self , index : int ) -> Tuple [torch .Tensor , torch .Tensor ]:
102- image_index = math .floor (index / BraTSDataset .IMAGE_DIMENSIONS [0 ])
103- slice_index = index - image_index * BraTSDataset .IMAGE_DIMENSIONS [0 ]
104- if image_index != self ._current_image_index :
105- self ._current_image_index = image_index
106- self ._current_image = self .images [self ._current_image_index ]
107- self ._current_mask = self .masks [self ._current_image_index ]
108-
109- x = torch .from_numpy (self ._current_image [slice_index , :, :])
110- y = torch .from_numpy (self ._current_mask [slice_index , :, :])
105+ if self .dimensionality == "2d" :
106+ image_index = math .floor (index / BraTSDataset .IMAGE_DIMENSIONS [0 ])
107+ slice_index = index - image_index * BraTSDataset .IMAGE_DIMENSIONS [0 ]
108+ if image_index != self ._current_image_index :
109+ self ._current_image_index = image_index
110+ self ._current_image = self .images [self ._current_image_index ]
111+ self ._current_mask = self .masks [self ._current_image_index ]
112+
113+ x = torch .from_numpy (self ._current_image [slice_index , :, :])
114+ y = torch .from_numpy (self ._current_mask [slice_index , :, :])
115+ else :
116+ x = torch .from_numpy (self .images [index ])
117+ y = torch .from_numpy (self .masks [index ])
111118
112119 if self .transform :
113120 x = self .transform (x )
0 commit comments