Skip to content

Commit 6a9a7af

Browse files
committed
Implement inferencing
1 parent 3468bc7 commit 6a9a7af

8 files changed

+120
-14
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,6 @@ lightning_logs
153153

154154
# Weight&Biases Logs
155155
wandb/
156+
157+
# Checkpoints
158+
checkpoints

README.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@ Additionally, the following optional arguments can be supplied in the same confi
8888
"epochs": 50,
8989
"gpus": 1,
9090
"loss": "dice",
91-
"optimizer": "adam"
91+
"optimizer": "adam",
92+
"prediction_count": None,
93+
"prediction_dir": "./predictions"
9294
}
9395
```
9496

brats_example_config.json

+4-2
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,7 @@
88
"optimizer": "adam",
99
"loss": "dice",
1010
"data_dir": "/dhc/groups/mpws2021cl1/Data",
11-
"gpus": 1
12-
}
11+
"gpus": 1,
12+
"prediction_count": 5,
13+
"prediction_dir": "/dhc/groups/mpws2021cl1/Predictions"
14+
}

src/datasets/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from datasets.data_module import ActiveLearningDataModule
33
from datasets.brats_data_module import BraTSDataModule
44
from datasets.pascal_voc_data_module import PascalVOCDataModule
5+
from datasets.brats_dataset import BraTSDataset

src/datasets/brats_data_module.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
""" Module containing the data module for brats data """
22
import os
3+
import random
34
from typing import Any, List, Optional, Union, Tuple
45
from torch.utils.data import Dataset
56

@@ -21,7 +22,9 @@ class BraTSDataModule(ActiveLearningDataModule):
2122
# pylint: disable=unused-argument,no-self-use
2223
@staticmethod
2324
def discover_paths(
24-
dir_path: str, modality: str = "flair"
25+
dir_path: str,
26+
modality: str = "flair",
27+
random_samples: Union[int, None] = None,
2528
) -> Tuple[List[str], List[str]]:
2629
"""
2730
Discover the .nii.gz file paths with a given modality
@@ -40,6 +43,9 @@ def discover_paths(
4043
if not case.startswith(".") and os.path.isdir(os.path.join(dir_path, case))
4144
]
4245

46+
if random_samples is not None and random_samples < len(cases):
47+
cases = random.sample(cases, random_samples)
48+
4349
image_paths = [
4450
os.path.join(dir_path, case, f"{os.path.basename(case)}_{modality}.nii.gz")
4551
for case in cases

src/datasets/brats_dataset.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
""" Module to load and batch brats dataset """
2-
from typing import Any, Callable, List, Optional, Tuple
2+
from typing import Any, Callable, List, Literal, Optional, Tuple
33
import math
44
import nibabel as nib
55
import numpy as np
@@ -73,6 +73,7 @@ def __init__(
7373
clip_mask: bool = True,
7474
transform: Optional[Callable[[Any], torch.Tensor]] = None,
7575
target_transform: Optional[Callable[[Any], torch.Tensor]] = None,
76+
dimensionality: Literal["2d", "3d"] = "2d",
7677
):
7778

7879
self.image_paths = image_paths
@@ -98,16 +99,22 @@ def __init__(
9899
self.transform = transform
99100
self.target_transform = target_transform
100101

102+
self.dimensionality = dimensionality
103+
101104
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
102-
image_index = math.floor(index / BraTSDataset.IMAGE_DIMENSIONS[0])
103-
slice_index = index - image_index * BraTSDataset.IMAGE_DIMENSIONS[0]
104-
if image_index != self._current_image_index:
105-
self._current_image_index = image_index
106-
self._current_image = self.images[self._current_image_index]
107-
self._current_mask = self.masks[self._current_image_index]
108-
109-
x = torch.from_numpy(self._current_image[slice_index, :, :])
110-
y = torch.from_numpy(self._current_mask[slice_index, :, :])
105+
if self.dimensionality == "2d":
106+
image_index = math.floor(index / BraTSDataset.IMAGE_DIMENSIONS[0])
107+
slice_index = index - image_index * BraTSDataset.IMAGE_DIMENSIONS[0]
108+
if image_index != self._current_image_index:
109+
self._current_image_index = image_index
110+
self._current_image = self.images[self._current_image_index]
111+
self._current_mask = self.masks[self._current_image_index]
112+
113+
x = torch.from_numpy(self._current_image[slice_index, :, :])
114+
y = torch.from_numpy(self._current_mask[slice_index, :, :])
115+
else:
116+
x = torch.from_numpy(self.images[index])
117+
y = torch.from_numpy(self.masks[index])
111118

112119
if self.transform:
113120
x = self.transform(x)

src/inferencing.py

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
""" Module containing inferencing logic """
2+
import os
3+
import uuid
4+
import torch
5+
import numpy as np
6+
import nibabel as nib
7+
from models import PytorchModel
8+
from datasets import BraTSDataModule, BraTSDataset
9+
10+
11+
class Inferencer:
12+
"""
13+
The inferencer to use a given model for inferencing.
14+
Args:
15+
model: A model object to be used for inferencing.
16+
dataset: Name of the dataset. E.g. 'brats'
17+
data_dir: Main directory with the dataset. E.g. './data'
18+
prediction_dir: Main directory with the predictions. E.g. './predictions'
19+
prediction_count: The amount of predictions to be generated.
20+
"""
21+
22+
def __init__(
23+
self,
24+
model: PytorchModel,
25+
dataset: str,
26+
data_dir: str,
27+
prediction_dir: str,
28+
prediction_count: int,
29+
) -> None:
30+
self.model = model
31+
self.dataset = dataset
32+
self.data_dir = data_dir
33+
self.prediction_dir = prediction_dir
34+
self.prediction_count = prediction_count
35+
36+
def inference(self) -> None:
37+
"""Run the inferencing."""
38+
if not self.dataset == "brats":
39+
print(f"Inferencing is not implemented for the {self.dataset} dataset.")
40+
return
41+
42+
output_folder_name = f"model-{str(uuid.uuid4())}"
43+
output_dir = os.path.join(self.prediction_dir, output_folder_name)
44+
os.mkdir(output_dir)
45+
46+
image_paths, annotation_paths = BraTSDataModule.discover_paths(
47+
self.data_dir,
48+
random_samples=self.prediction_count,
49+
)
50+
data = BraTSDataset(
51+
image_paths=image_paths,
52+
annotation_paths=annotation_paths,
53+
dimensionality="3d",
54+
)
55+
56+
for i in range(self.prediction_count):
57+
x, _ = data.__getitem__(i)
58+
59+
x = torch.swapaxes(x, 0, 1)
60+
pred = self.model(x)
61+
pred = torch.swapaxes(pred, 0, 1)
62+
63+
seg = pred.detach().cpu().numpy()[0]
64+
seg = (seg >= 0.5) * 255
65+
seg = np.moveaxis(seg, 0, 2)
66+
seg = np.rot90(seg, 2, (0, 1))
67+
seg = seg.astype("float64")
68+
69+
img = nib.Nifti1Image(seg, np.eye(4))
70+
file_name = os.path.basename(annotation_paths[i]).replace("seg", "pred")
71+
path = os.path.join(output_dir, file_name)
72+
nib.save(img, path)
73+
print(path)

src/main.py

+12
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
""" Main module to execute active learning pipeline from CLI """
22
import json
33
import os.path
4+
from typing import Union
45
import fire
56
from active_learning import ActiveLearningPipeline
7+
from inferencing import Inferencer
68
from models import PytorchFCNResnet50, PytorchUNet
79
from datasets import BraTSDataModule, PascalVOCDataModule
810
from query_strategies import QueryStrategy
@@ -20,6 +22,8 @@ def run_active_learning_pipeline(
2022
gpus: int = 1,
2123
loss: str = "dice",
2224
optimizer: str = "adam",
25+
prediction_count: Union[int, None] = None,
26+
prediction_dir: str = "./predictions",
2327
) -> None:
2428
"""
2529
Main function to execute an active learning pipeline run, or start an active learning simulation.
@@ -60,6 +64,14 @@ def run_active_learning_pipeline(
6064
pipeline = ActiveLearningPipeline(data_module, model, strategy, epochs, gpus)
6165
pipeline.run()
6266

67+
if prediction_count is None:
68+
return
69+
70+
inferencer = Inferencer(
71+
model, dataset, os.path.join(data_dir, "val"), prediction_dir, prediction_count
72+
)
73+
inferencer.inference()
74+
6375

6476
def run_active_learning_pipeline_from_config(config_file_name: str) -> None:
6577
"""

0 commit comments

Comments
 (0)