diff --git a/ICON/LICENSE b/ICON/LICENSE new file mode 100644 index 00000000..db8ea3b3 --- /dev/null +++ b/ICON/LICENSE @@ -0,0 +1,13 @@ +Copyright 2017, Hastings Greer + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/ICON/icon_registration/__init__.py b/ICON/icon_registration/__init__.py new file mode 100644 index 00000000..5ed2873b --- /dev/null +++ b/ICON/icon_registration/__init__.py @@ -0,0 +1,34 @@ +from icon_registration.losses import ( + GradICON, + ICON, + BendingEnergy, + Diffusion, + VelocityFieldBendingEnergy, + VelocityFieldDiffusion, +) + +from icon_registration.similarity import ( + LNCC, + LNCCOnlyInterpolated, + BlurredSSD, + SSDOnlyInterpolated, + SSD, + NCC, +) +from icon_registration.network_wrappers import ( + InverseConsistentAffine, + InverseConsistentVelocityField, + Downsample, + TwoStep, + Affine, + VelocityField, + DisplacementField, +) +from icon_registration.train import train_batchfunction, train_datasets + + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from icon_registration.registration_module import RegistrationModule diff --git a/ICON/icon_registration/config.py b/ICON/icon_registration/config.py new file mode 100644 index 00000000..bc7c42c0 --- /dev/null +++ b/ICON/icon_registration/config.py @@ -0,0 +1,6 @@ +import torch + +if torch.cuda.is_available(): + device = torch.device("cuda") +else: + device = torch.device("cpu") diff --git a/ICON/icon_registration/data.py b/ICON/icon_registration/data.py new file mode 100644 index 00000000..6905232b --- /dev/null +++ b/ICON/icon_registration/data.py @@ -0,0 +1,480 @@ +import random + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision +import tqdm + +from icon_registration import config + + +def get_dataset_mnist(split, number=5): + ds = torch.utils.data.DataLoader( + torchvision.datasets.MNIST( + "./files/", + transform=torchvision.transforms.ToTensor(), + download=True, + train=(split == "train"), + ), + batch_size=500, + ) + images = [] + for _, batch in enumerate(ds): + label = np.array(batch[1]) + batch_nines = label == number + images.append(np.array(batch[0])[batch_nines]) + images = np.concatenate(images) + + ds = torch.utils.data.TensorDataset(torch.Tensor(images)) + d1, d2 = ( + torch.utils.data.DataLoader( + ds, + batch_size=128, + shuffle=True, + ) + for _ in (1, 1) + ) + return d1, d2 + + +def get_dataset_1d(data_size=128, samples=6000, batch_size=128): + x = np.arange(0, 1, 1 / data_size) + x = np.reshape(x, (1, data_size)) + cx = np.random.random((samples, 1)) * 0.3 + 0.4 + r = np.random.random((samples, 1)) * 0.2 + 0.2 + + circles = np.tanh(-40 * (np.sqrt((x - cx) ** 2) - r)) + + ds = torch.utils.data.TensorDataset(torch.Tensor(np.expand_dims(circles, 1))) + d1, d2 = ( + torch.utils.data.DataLoader( + ds, + batch_size=batch_size, + shuffle=True, + ) + for _ in (1, 1) + ) + return d1, d2 + + +def get_dataset_triangles( + split=None, data_size=128, hollow=False, samples=6000, batch_size=128 +): + x, y = np.mgrid[0 : 1 : data_size 
* 1j, 0 : 1 : data_size * 1j] + x = np.reshape(x, (1, data_size, data_size)) + y = np.reshape(y, (1, data_size, data_size)) + cx = np.random.random((samples, 1, 1)) * 0.3 + 0.4 + cy = np.random.random((samples, 1, 1)) * 0.3 + 0.4 + r = np.random.random((samples, 1, 1)) * 0.2 + 0.2 + theta = np.random.random((samples, 1, 1)) * np.pi * 2 + isTriangle = np.random.random((samples, 1, 1)) > 0.5 + + triangles = np.sqrt((x - cx) ** 2 + (y - cy) ** 2) - r * np.cos(np.pi / 3) / np.cos( + (np.arctan2(x - cx, y - cy) + theta) % (2 * np.pi / 3) - np.pi / 3 + ) + + triangles = np.tanh(-40 * triangles) + + circles = np.tanh(-40 * (np.sqrt((x - cx) ** 2 + (y - cy) ** 2) - r)) + if hollow: + triangles = 1 - triangles**2 + circles = 1 - circles**2 + + images = isTriangle * triangles + (1 - isTriangle) * circles + + ds = torch.utils.data.TensorDataset(torch.Tensor(np.expand_dims(images, 1))) + d1, d2 = ( + torch.utils.data.DataLoader( + ds, + batch_size=batch_size, + shuffle=True, + ) + for _ in (1, 1) + ) + return d1, d2 + + +def get_dataset_retina( + extra_deformation=False, + downsample_factor=4, + blur_sigma=None, + warps_per_pair=20, + fixed_vertical_offset=None, + include_boundary=False, +): + try: + import elasticdeform + import hub + except: + raise Exception( + """the retina dataset requires the dependencies hub and elasticdeform. + Try pip install hub elasticdeform""" + ) + + ds_name = f"retina{extra_deformation}{downsample_factor}{blur_sigma}{warps_per_pair}{fixed_vertical_offset}{include_boundary}.trch" + + import os + + if os.path.exists(ds_name): + augmented_ds1_tensor, augmented_ds2_tensor = torch.load(ds_name) + else: + res = [] + for batch in hub.load("hub://activeloop/drive-train").pytorch( + num_workers=0, batch_size=4, shuffle=False + ): + if include_boundary: + res.append(batch["manual_masks/mask"] ^ batch["masks/mask"]) + else: + res.append(batch["manual_masks/mask"]) + res = torch.cat(res) + ds_tensor = res[:, None, :, :, 0] * -1.0 + (not include_boundary) + + if fixed_vertical_offset is not None: + ds2_tensor = torch.cat( + [torch.zeros(20, 1, fixed_vertical_offset, 565), ds_tensor], axis=2 + ) + ds1_tensor = torch.cat( + [ds_tensor, torch.zeros(20, 1, fixed_vertical_offset, 565)], axis=2 + ) + else: + ds2_tensor = ds_tensor + ds1_tensor = ds_tensor + + warped_tensors = [] + print("warping images to generate dataset") + for _ in tqdm.tqdm(range(warps_per_pair)): + ds_2_list = [] + for el in ds2_tensor: + case = el[0] + # TODO implement random warping on gpu + case_warped = np.array(case) + if extra_deformation: + case_warped = elasticdeform.deform_random_grid( + case_warped, sigma=60, points=3 + ) + case_warped = elasticdeform.deform_random_grid( + case_warped, sigma=25, points=3 + ) + + case_warped = elasticdeform.deform_random_grid( + case_warped, sigma=12, points=6 + ) + ds_2_list.append(torch.tensor(case_warped)[None, None, :, :]) + ds_2_tensor = torch.cat(ds_2_list) + warped_tensors.append(ds_2_tensor) + + augmented_ds2_tensor = torch.cat(warped_tensors) + augmented_ds1_tensor = torch.cat([ds1_tensor for _ in range(warps_per_pair)]) + + torch.save((augmented_ds1_tensor, augmented_ds2_tensor), ds_name) + + batch_size = 10 + import torchvision.transforms.functional as Fv + + if blur_sigma is None: + ds1 = torch.utils.data.TensorDataset( + F.avg_pool2d(augmented_ds1_tensor, downsample_factor) + ) + else: + ds1 = torch.utils.data.TensorDataset( + Fv.gaussian_blur( + F.avg_pool2d(augmented_ds1_tensor, downsample_factor), + 4 * blur_sigma + 1, + blur_sigma, + ) + ) + d1 = 
torch.utils.data.DataLoader( + ds1, + batch_size=batch_size, + shuffle=False, + ) + if blur_sigma is None: + ds2 = torch.utils.data.TensorDataset( + F.avg_pool2d(augmented_ds2_tensor, downsample_factor) + ) + else: + ds2 = torch.utils.data.TensorDataset( + Fv.gaussian_blur( + F.avg_pool2d(augmented_ds2_tensor, downsample_factor), + 4 * blur_sigma + 1, + blur_sigma, + ) + ) + + d2 = torch.utils.data.DataLoader( + ds2, + batch_size=batch_size, + shuffle=False, + ) + + return d1, d2 + + +def get_dataset_sunnyside(split, scale=1): + import pickle + + with open("/playpen/tgreer/sunnyside.pickle", "rb") as f: + array = pickle.load(f) + if split == "train": + array = array[1000:] + elif split == "test": + array = array[:1000] + else: + raise ArgumentError() + + array = array[:, :, :, 0] + array = np.expand_dims(array, 1) + array = array * scale + array1 = array[::2] + array2 = array[1::2] + array12 = np.concatenate([array2, array1]) + array21 = np.concatenate([array1, array2]) + ds = torch.utils.data.TensorDataset(torch.Tensor(array21), torch.Tensor(array12)) + ds = torch.utils.data.DataLoader( + ds, + batch_size=128, + shuffle=True, + ) + return ds + + +def get_cartilage_dataset(): + cartilage = torch.load("/playpen/tgreer/cartilage_uint8s.trch") + return cartilage + + +def get_knees_dataset(): + brains = torch.load("/playpen/tgreer/kneestorch") + # with open("/playpen/tgreer/cartilage_eval_oriented", "rb") as f: + # cartilage = pickle.load(f) + + medbrains = [] + for b in brains: + medbrains.append(F.avg_pool3d(b, 4)) + + return brains, medbrains + + +def get_copdgene_dataset( + data_folder, cache_folder="./data_cache", lung_only=True, downscale=2 +): + """ + This function load the preprocessed COPDGene train set. + """ + import os + + def process(iA, downscale, clamp=[-1000, 0], isSeg=False): + iA = iA[None, None, :, :, :] + # SI flip + iA = torch.flip(iA, dims=(2,)) + if isSeg: + iA = iA.float() + iA = torch.nn.functional.max_pool3d(iA, downscale) + iA[iA > 0] = 1 + else: + iA = torch.clip(iA, clamp[0], clamp[1]) + clamp[0] + # TODO: For compatibility to the processed dataset(ranges between -1 to 0) used in paper, we subtract -1 here. + # Should remove -1 later. 
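+            # Normalize by the clamped maximum, shift by -1, then average-pool by `downscale`.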
+ iA = iA / torch.max(iA) - 1.0 + iA = torch.nn.functional.avg_pool3d(iA, downscale) + return iA + + cache_name = f"{cache_folder}/lungs_train_{downscale}xdown_scaled" + if os.path.exists(cache_name): + imgs = torch.load(cache_name, map_location="cpu") + if lung_only: + try: + masks = torch.load( + f"{cache_folder}/lungs_seg_train_{downscale}xdown_scaled", + map_location="cpu", + ) + except FileNotFoundError: + print("Segmentation data not found.") + + else: + import itk + import glob + + with open(f"{data_folder}/splits/train.txt") as f: + pair_paths = f.readlines() + imgs = [] + masks = [] + for name in tqdm.tqdm(list(iter(pair_paths))[:]): + name = name[:-1] # remove newline + + image_insp = torch.tensor( + np.asarray( + itk.imread( + glob.glob( + f"{data_folder} /{name}/{name}_INSP_STD*_COPD_img.nii.gz" + )[0] + ) + ) + ) + image_exp = torch.tensor( + np.asarray( + itk.imread( + glob.glob( + f"{data_folder} /{name}/{name}_EXP_STD*_COPD_img.nii.gz" + )[0] + ) + ) + ) + imgs.append((process(image_insp), process(image_exp))) + + seg_insp = torch.tensor( + np.asarray( + itk.imread( + glob.glob( + f"{data_folder} /{name}/{name}_INSP_STD*_COPD_label.nii.gz" + )[0] + ) + ) + ) + seg_exp = torch.tensor( + np.asarray( + itk.imread( + glob.glob( + f"{data_folder} /{name}/{name}_EXP_STD*_COPD_label.nii.gz" + )[0] + ) + ) + ) + masks.append((process(seg_insp, True), process(seg_exp, True))) + + torch.save(imgs, f"{cache_folder}/lungs_train_{downscale}xdown_scaled") + torch.save(masks, f"{cache_folder}/lungs_seg_train_{downscale}xdown_scaled") + + if lung_only: + imgs = torch.cat( + [(torch.cat(d, 1) + 1) * torch.cat(m, 1) for d, m in zip(imgs, masks)], + dim=0, + ) + else: + imgs = torch.cat([torch.cat(d, 1) + 1 for d in imgs], dim=0) + return torch.utils.data.TensorDataset(imgs) + + +def get_learn2reg_AbdomenCTCT_dataset( + data_folder, cache_folder="./data_cache", clamp=[-1000, 0], downscale=1 +): + """ + This function will return the training dataset of AbdomenCTCT registration task in learn2reg. + """ + + # Check whether we have cached the dataset + import os + + cache_name = ( + f"{cache_folder}/learn2reg_abdomenctct_train_set_clamp{clamp}scale{downscale}" + ) + if os.path.exists(cache_name): + imgs = torch.load(cache_name) + else: + import json + import itk + import glob + + with open(f"{data_folder}/AbdomenCTCT_dataset.json", "r") as data_info: + data_info = json.loads(data_info.read()) + train_cases = [ + c["image"].split("/")[-1].split(".")[0] for c in data_info["training"] + ] + imgs = [ + np.asarray( + itk.imread(glob.glob(data_folder + "/imagesTr/" + i + ".nii.gz")[0]) + ) + for i in train_cases + ] + + imgs = torch.Tensor(np.expand_dims(np.array(imgs), axis=1)).float() + imgs = (torch.clamp(imgs, clamp[0], clamp[1]) - clamp[0]) / ( + clamp[1] - clamp[0] + ) + + # Cache the data + if not os.path.exists(cache_folder): + os.makedirs(cache_folder) + torch.save(imgs, cache_name) + + # Scale down the image + if downscale > 1: + imgs = F.avg_pool3d(imgs, downscale) + return torch.utils.data.TensorDataset(imgs) + + +def get_learn2reg_lungCT_dataset( + data_folder, + cache_folder="./data_cache", + lung_only=True, + clamp=[-1000, 0], + downscale=1, +): + """ + This function will return the training dataset of LungCT registration task in learn2reg. 
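+    Images are masked to the lung region when `lung_only` is True (background set to clamp[0]),
+    clamped to `clamp`, rescaled to [0, 1], optionally average-pooled by `downscale`,
+    and cached under `cache_folder`.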
+ """ + import os + + cache_name = ( + f"{cache_folder}/learn2reg_lung_train_set_lung_only" + if lung_only + else f"{cache_folder}/learn2reg_lung_train_set" + ) + cache_name += f"_clamp{clamp}scale{downscale}" + if os.path.exists(cache_name): + imgs = torch.load(cache_name) + else: + import json + import itk + import glob + + with open(f"{data_folder}/NLST_dataset.json", "r") as data_info: + data_info = json.loads(data_info.read()) + train_pairs = [ + [p["fixed"].split("/")[-1], p["moving"].split("/")[-1]] + for p in data_info["training_paired_images"] + ] + imgs = [] + for p in train_pairs: + img = np.array( + [ + np.asarray(itk.imread(glob.glob(data_folder + "/imagesTr/" + i)[0])) + for i in p + ] + ) + if lung_only: + mask = np.array( + [ + np.asarray( + itk.imread( + glob.glob(data_folder + "/" + "/masksTr/" + i)[0] + ) + ) + for i in p + ] + ) + img = img * mask + clamp[0] * (1 - mask) + imgs.append(img) + + imgs = torch.Tensor(np.array(imgs)).float() + imgs = (torch.clamp(imgs, clamp[0], clamp[1]) - clamp[0]) / ( + clamp[1] - clamp[0] + ) + + # Cache the data + if not os.path.exists(cache_folder): + os.makedirs(cache_folder) + torch.save(imgs, cache_name) + + # Scale down the image + if downscale > 1: + imgs = F.avg_pool3d(imgs, downscale) + return torch.utils.data.TensorDataset(imgs) + + +def make_batch(data, BATCH_SIZE, SCALE): + image = torch.cat([random.choice(data) for _ in range(BATCH_SIZE)]) + image = image.reshape(BATCH_SIZE, 1, SCALE * 40, SCALE * 96, SCALE * 96) + image = image.to(config.device) + return image diff --git a/ICON/icon_registration/itk_wrapper.py b/ICON/icon_registration/itk_wrapper.py new file mode 100644 index 00000000..5559cb6b --- /dev/null +++ b/ICON/icon_registration/itk_wrapper.py @@ -0,0 +1,233 @@ +import copy + +import itk +import numpy as np +import torch +import torch.nn.functional as F + +from icon_registration import config +from icon_registration.losses import to_floats + + +def finetune_execute(model, image_A, image_B, steps): + state_dict = copy.deepcopy(model.state_dict()) + optimizer = torch.optim.Adam(model.parameters(), lr=0.00002) + for _ in range(steps): + optimizer.zero_grad() + loss_tuple = model(image_A, image_B) + print(loss_tuple) + loss_tuple[0].backward() + optimizer.step() + with torch.no_grad(): + loss = model(image_A, image_B) + model.load_state_dict(state_dict) + return loss + + +def register_pair( + model, image_A, image_B, finetune_steps=None, return_artifacts=False +) -> "(itk.CompositeTransform, itk.CompositeTransform)": + assert isinstance(image_A, itk.Image) + assert isinstance(image_B, itk.Image) + + # send model to cpu or gpu depending on config- auto detects capability + model.to(config.device) + + A_npy = np.array(image_A) + B_npy = np.array(image_B) + + assert np.max(A_npy) != np.min(A_npy) + assert np.max(B_npy) != np.min(B_npy) + # turn images into torch Tensors: add feature and batch dimensions (each of length 1) + A_trch = torch.Tensor(A_npy).to(config.device)[None, None] + B_trch = torch.Tensor(B_npy).to(config.device)[None, None] + + shape = model.identity_map.shape + + # Here we resize the input images to the shape expected by the neural network. This affects the + # pixel stride as well as the magnitude of the displacement vectors of the resulting + # displacement field, which create_itk_transform will have to compensate for. 
+ A_resized = F.interpolate( + A_trch, size=shape[2:], mode="trilinear", align_corners=False + ) + B_resized = F.interpolate( + B_trch, size=shape[2:], mode="trilinear", align_corners=False + ) + if finetune_steps == 0: + raise Exception("To indicate no finetune_steps, pass finetune_steps=None") + + if finetune_steps == None: + with torch.no_grad(): + loss = model(A_resized, B_resized) + else: + loss = finetune_execute(model, A_resized, B_resized, finetune_steps) + + # phi_AB and phi_BA are [1, 3, H, W, D] pytorch tensors representing the forward and backward + # maps computed by the model + if hasattr(model, "prepare_for_viz"): + with torch.no_grad(): + model.prepare_for_viz(A_resized, B_resized) + phi_AB = model.phi_AB(model.identity_map) + phi_BA = model.phi_BA(model.identity_map) + + # the parameters ident, image_A, and image_B are used for their metadata + itk_transforms = ( + create_itk_transform(phi_AB, model.identity_map, image_A, image_B), + create_itk_transform(phi_BA, model.identity_map, image_B, image_A), + ) + if not return_artifacts: + return itk_transforms + else: + return itk_transforms + (to_floats(loss),) + + +def register_pair_with_multimodalities( + model, image_A: list, image_B: list, finetune_steps=None, return_artifacts=False +) -> "(itk.CompositeTransform, itk.CompositeTransform)": + assert len(image_A) == len( + image_B + ), "image_A and image_B should have the same number of modalities." + + # send model to cpu or gpu depending on config- auto detects capability + model.to(config.device) + + A_npy, B_npy = [], [] + for image_a, image_b in zip(image_A, image_B): + assert isinstance(image_a, itk.Image) + assert isinstance(image_b, itk.Image) + + A_npy.append(np.array(image_a)) + B_npy.append(np.array(image_b)) + + assert np.max(A_npy[-1]) != np.min(A_npy[-1]) + assert np.max(B_npy[-1]) != np.min(B_npy[-1]) + + # turn images into torch Tensors: add batch dimensions (each of length 1) + A_trch = torch.Tensor(np.array(A_npy)).to(config.device)[None] + B_trch = torch.Tensor(np.array(B_npy)).to(config.device)[None] + + shape = model.identity_map.shape[2:] + if list(A_trch.shape[2:]) != list(shape) or (list(B_trch.shape[2:]) != list(shape)): + # Here we resize the input images to the shape expected by the neural network. This affects the + # pixel stride as well as the magnitude of the displacement vectors of the resulting + # displacement field, which create_itk_transform will have to compensate for. 
+ A_trch = F.interpolate( + A_trch, size=shape, mode="trilinear", align_corners=False + ) + B_trch = F.interpolate( + B_trch, size=shape, mode="trilinear", align_corners=False + ) + + if finetune_steps == 0: + raise Exception("To indicate no finetune_steps, pass finetune_steps=None") + + if finetune_steps == None: + with torch.no_grad(): + loss = model(A_trch, B_trch) + else: + loss = finetune_execute(model, A_trch, B_trch, finetune_steps) + + # phi_AB and phi_BA are [1, 3, H, W, D] pytorch tensors representing the forward and backward + # maps computed by the model + if hasattr(model, "prepare_for_viz"): + with torch.no_grad(): + model.prepare_for_viz(A_trch, B_trch) + phi_AB = model.phi_AB(model.identity_map) + phi_BA = model.phi_BA(model.identity_map) + + # the parameters ident, image_A, and image_B are used for their metadata + itk_transforms = ( + create_itk_transform(phi_AB, model.identity_map, image_A[0], image_B[0]), + create_itk_transform(phi_BA, model.identity_map, image_B[0], image_A[0]), + ) + if not return_artifacts: + return itk_transforms + else: + return itk_transforms + (to_floats(loss),) + + +def create_itk_transform(phi, ident, image_A, image_B) -> "itk.CompositeTransform": + # itk.DeformationFieldTransform expects a displacement field, so we subtract off the identity map. + disp = (phi - ident)[0].cpu() + + network_shape_list = list(ident.shape[2:]) + + dimension = len(network_shape_list) + + tr = itk.DisplacementFieldTransform[(itk.D, dimension)].New() + + # We convert the displacement field into an itk Vector Image. + scale = torch.Tensor(network_shape_list) + + for _ in network_shape_list: + scale = scale[:, None] + disp *= scale - 1 + + # disp is a shape [3, H, W, D] tensor with vector components in the order [vi, vj, vk] + disp_itk_format = ( + disp.double() + .numpy()[list(reversed(range(dimension)))] + .transpose(list(range(1, dimension + 1)) + [0]) + ) + # disp_itk_format is a shape [H, W, D, 3] array with vector components in the order [vk, vj, vi] + # as expected by itk. 
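+    # The displacement field acts in the network's index space; the composite transform assembled
+    # below sandwiches it between a resampling into network space and the inverse resampling back.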
+ + itk_disp_field = itk.image_from_array(disp_itk_format, is_vector=True) + + tr.SetDisplacementField(itk_disp_field) + + to_network_space = resampling_transform(image_A, list(reversed(network_shape_list))) + + from_network_space = resampling_transform( + image_B, list(reversed(network_shape_list)) + ).GetInverseTransform() + + phi_AB_itk = itk.CompositeTransform[itk.D, dimension].New() + + phi_AB_itk.PrependTransform(from_network_space) + phi_AB_itk.PrependTransform(tr) + phi_AB_itk.PrependTransform(to_network_space) + + # warp(image_A, phi_AB_itk) is close to image_B + + return phi_AB_itk + + +def resampling_transform(image, shape): + imageType = itk.template(image)[0][itk.template(image)[1]] + + dummy_image = itk.image_from_array( + np.zeros(tuple(reversed(shape)), dtype=itk.array_from_image(image).dtype) + ) + if len(shape) == 2: + transformType = itk.MatrixOffsetTransformBase[itk.D, 2, 2] + else: + transformType = itk.VersorRigid3DTransform[itk.D] + initType = itk.CenteredTransformInitializer[transformType, imageType, imageType] + initializer = initType.New() + initializer.SetFixedImage(dummy_image) + initializer.SetMovingImage(image) + transform = transformType.New() + + initializer.SetTransform(transform) + initializer.InitializeTransform() + + if len(shape) == 3: + transformType = itk.CenteredAffineTransform[itk.D, 3] + t2 = transformType.New() + t2.SetCenter(transform.GetCenter()) + t2.SetOffset(transform.GetOffset()) + transform = t2 + m = transform.GetMatrix() + m_a = itk.array_from_matrix(m) + + input_shape = image.GetLargestPossibleRegion().GetSize() + + for i in range(len(shape)): + m_a[i, i] = image.GetSpacing()[i] * (input_shape[i] / shape[i]) + + m_a = itk.array_from_matrix(image.GetDirection()) @ m_a + + transform.SetMatrix(itk.matrix_from_array(m_a)) + + return transform diff --git a/ICON/icon_registration/losses.py b/ICON/icon_registration/losses.py new file mode 100644 index 00000000..f365ac89 --- /dev/null +++ b/ICON/icon_registration/losses.py @@ -0,0 +1,426 @@ +from collections import namedtuple + +import matplotlib +import torch +import torch.nn.functional as F + +from icon_registration import config, network_wrappers, registration_module + +def flips(phi, in_percentage=False): + if len(phi.size()) == 5: + a = (phi[:, :, 1:, 1:, 1:] - phi[:, :, :-1, 1:, 1:]).detach() + b = (phi[:, :, 1:, 1:, 1:] - phi[:, :, 1:, :-1, 1:]).detach() + c = (phi[:, :, 1:, 1:, 1:] - phi[:, :, 1:, 1:, :-1]).detach() + + dV = torch.sum(torch.cross(a, b, 1) * c, axis=1, keepdims=True) + if in_percentage: + return torch.mean((dV < 0).float()) * 100. + else: + return torch.sum(dV < 0) / phi.shape[0] + elif len(phi.size()) == 4: + du = (phi[:, :, 1:, :-1] - phi[:, :, :-1, :-1]).detach() + dv = (phi[:, :, :-1, 1:] - phi[:, :, :-1, :-1]).detach() + dA = du[:, 0] * dv[:, 1] - du[:, 1] * dv[:, 0] + if in_percentage: + return torch.mean((dA < 0).float()) * 100. + else: + return torch.sum(dA < 0) / phi.shape[0] + elif len(phi.size()) == 3: + du = (phi[:, :, 1:] - phi[:, :, :-1]).detach() + if in_percentage: + return torch.mean((du < 0).float()) * 100. + else: + return torch.sum(du < 0) / phi.shape[0] + else: + raise ValueError() + +def to_floats(result_dictionary): + """ + Calling forward on the modules in this file returns a rich result dictionary of differentiable torch objects. This function + converts those to scalars suitable for logging in e.g. 
tensorboard + """ + + out = {} + for key, value in result_dictionary.items(): + if isinstance(value, float) or isinstance(value, int): + out[key] = value + elif isinstance(value, torch.Tensor) and value.shape == (): + out[key] = value.cpu().item() + return out + + +class Loss(registration_module.RegistrationModule): + def __init__(self, network, similarity, lmbda): + super().__init__() + self.regis_net = network + self.lmbda = lmbda + self.similarity = similarity + + def compute_similarity(self, image_A, image_B, phi_AB_vectorfield): + if getattr(self.similarity, "isInterpolated", False): + # tag images during warping so that the similarity measure + # can use information about whether a sample is interpolated + # or extrapolated + inbounds_tag = torch.zeros( + [image_A.shape[0]] + [1] + list(image_A.shape[2:]), + device=image_A.device, + ) + if len(self.input_shape) - 2 == 3: + inbounds_tag[:, :, 1:-1, 1:-1, 1:-1] = 1.0 + elif len(self.input_shape) - 2 == 2: + inbounds_tag[:, :, 1:-1, 1:-1] = 1.0 + else: + inbounds_tag[:, :, 1:-1] = 1.0 + else: + inbounds_tag = None + + warped_image_A = self.as_function( + torch.cat([image_A, inbounds_tag], axis=1) + if inbounds_tag is not None + else image_A + )( + phi_AB_vectorfield, + ) + similarity_loss = self.similarity(warped_image_A, image_B) + return {"similarity_loss": similarity_loss, "warped_image_A": warped_image_A} + + +class TwoWayRegularizer(Loss): + def forward(self, image_A, image_B): + assert self.identity_map.shape[2:] == image_A.shape[2:] + assert self.identity_map.shape[2:] == image_B.shape[2:] + + # Tag used elsewhere for optimization. + # Must be set at beginning of forward b/c not preserved by .cuda() etc + self.identity_map.isIdentity = True + + phi_AB = self.regis_net(image_A, image_B)["phi_AB"] + phi_BA = self.regis_net(image_B, image_A)["phi_AB"] + + phi_AB_vectorfield = phi_AB(self.identity_map) + phi_BA_vectorfield = phi_BA(self.identity_map) + + similarity_AB = self.compute_similarity(image_A, image_B, phi_AB_vectorfield) + similarity_BA = self.compute_similarity(image_B, image_A, phi_BA_vectorfield) + + similarity_loss = ( + similarity_AB["similarity_loss"] + similarity_BA["similarity_loss"] + ) + regularization_loss = self.compute_regularizer(phi_AB, phi_BA) + + all_loss = self.lmbda * regularization_loss + similarity_loss + + negative_jacobian_voxels = flips(phi_BA_vectorfield) + + return { + "all_loss": all_loss, + "regularization_loss": regularization_loss, + "similarity_loss": similarity_loss, + "phi_AB": phi_AB, + "phi_BA": phi_BA, + "warped_image_A": similarity_AB["warped_image_A"], + "warped_image_B": similarity_BA["warped_image_A"], + "negative_jacobian_voxels": negative_jacobian_voxels, + } + + +class ICON(TwoWayRegularizer): + def compute_regularizer(self, phi_AB, phi_BA): + Iepsilon = self.identity_map + torch.randn(*self.identity_map.shape).to( + self.identity_map.device + ) + + approximate_Iepsilon1 = phi_AB(phi_BA(Iepsilon)) + + approximate_Iepsilon2 = phi_BA(phi_AB(Iepsilon)) + + inverse_consistency_loss = torch.mean( + (Iepsilon - approximate_Iepsilon1) ** 2 + ) + torch.mean((Iepsilon - approximate_Iepsilon2) ** 2) + + inverse_consistency_loss /= self.input_shape[2] ** 2 + + return inverse_consistency_loss + + +class GradICON(TwoWayRegularizer): + def compute_regularizer(self, phi_AB, phi_BA): + Iepsilon = self.identity_map + torch.randn(*self.identity_map.shape).to( + self.identity_map.device + ) + if len(self.input_shape) - 2 == 3: + Iepsilon = Iepsilon[:, :, ::2, ::2, ::2] + elif len(self.input_shape) - 
2 == 2: + Iepsilon = Iepsilon[:, :, ::2, ::2] + + # compute squared Frobenius of Jacobian of icon error + + direction_losses = [] + + approximate_Iepsilon = phi_AB(phi_BA(Iepsilon)) + + inverse_consistency_error = Iepsilon - approximate_Iepsilon + + delta = 0.5 + + if len(self.identity_map.shape) == 4: + dx = torch.Tensor([[[[delta]], [[0.0]]]]).to(self.identity_map.device) + dy = torch.Tensor([[[[0.0]], [[delta]]]]).to(self.identity_map.device) + direction_vectors = (dx, dy) + + elif len(self.identity_map.shape) == 5: + dx = torch.Tensor([[[[[delta]]], [[[0.0]]], [[[0.0]]]]]).to( + self.identity_map.device + ) + dy = torch.Tensor([[[[[0.0]]], [[[delta]]], [[[0.0]]]]]).to( + self.identity_map.device + ) + dz = torch.Tensor([[[[0.0]]], [[[0.0]]], [[[delta]]]]).to( + self.identity_map.device + ) + direction_vectors = (dx, dy, dz) + elif len(self.identity_map.shape) == 3: + dx = torch.Tensor([[[delta]]]).to(self.identity_map.device) + direction_vectors = (dx,) + + for d in direction_vectors: + approximate_Iepsilon_d = phi_AB(phi_BA(Iepsilon + d)) + inverse_consistency_error_d = Iepsilon + d - approximate_Iepsilon_d + grad_d_icon_error = ( + inverse_consistency_error - inverse_consistency_error_d + ) / delta + direction_losses.append(torch.mean(grad_d_icon_error**2)) + + gradient_inverse_consistency_loss = sum(direction_losses) + + return gradient_inverse_consistency_loss + + +class OneWayRegularizer(Loss): + def forward(self, image_A, image_B): + assert self.identity_map.shape[2:] == image_A.shape[2:] + assert self.identity_map.shape[2:] == image_B.shape[2:] + + # Tag used elsewhere for optimization. + # Must be set at beginning of forward b/c not preserved by .cuda() etc + self.identity_map.isIdentity = True + + phi_AB = self.regis_net(image_A, image_B)["phi_AB"] + + phi_AB_vectorfield = phi_AB(self.identity_map) + + similarity_AB = self.compute_similarity(image_A, image_B, phi_AB_vectorfield) + + similarity_loss = 2 * similarity_AB["similarity_loss"] + regularization_loss = self.compute_regularizer(phi_AB_vectorfield) + + all_loss = self.lmbda * regularization_loss + similarity_loss + + negative_jacobian_voxels = flips(phi_AB_vectorfield) + + return { + "all_loss": all_loss, + "regularization_loss": regularization_loss, + "similarity_loss": similarity_loss, + "phi_AB": phi_AB, + "warped_image_A": similarity_AB["warped_image_A"], + "negative_jacobian_voxels": negative_jacobian_voxels, + } + + +class BendingEnergy(OneWayRegularizer): + def compute_regularizer(self, phi_AB_vectorfield): + # dxdx = [f[x+h, y] + f[x-h, y] - 2 * f[x, y]]/(h**2) + # dxdy = [f[x+h, y+h] + f[x-h, y-h] - f[x+h, y-h] - f[x-h, y+h]]/(4*h**2) + # BE_2d = |dxdx| + |dydy| + 2 * |dxdy| + # pseudo code: BE_2d = [torch.mean(dxdx**2) + torch.mean(dydy**2) + 2 * torch.mean(dxdy**2)]/4.0 + # BE_3d = |dxdx| + |dydy| + |dzdz| + 2 * |dxdy| + 2 * |dydz| + 2 * |dxdz| + + if len(self.identity_map.shape) == 3: + dxdx = ( + phi_AB_vectorfield[:, :, 2:] + - 2 * phi_AB_vectorfield[:, :, 1:-1] + + phi_AB_vectorfield[:, :, :-2] + ) / self.spacing[0] ** 2 + bending_energy = torch.mean((dxdx) ** 2) + + elif len(self.identity_map.shape) == 4: + dxdx = ( + phi_AB_vectorfield[:, :, 2:] + - 2 * phi_AB_vectorfield[:, :, 1:-1] + + phi_AB_vectorfield[:, :, :-2] + ) / self.spacing[0] ** 2 + dydy = ( + phi_AB_vectorfield[:, :, :, 2:] + - 2 * phi_AB_vectorfield[:, :, :, 1:-1] + + phi_AB_vectorfield[:, :, :, :-2] + ) / self.spacing[1] ** 2 + dxdy = ( + phi_AB_vectorfield[:, :, 2:, 2:] + + phi_AB_vectorfield[:, :, :-2, :-2] + - phi_AB_vectorfield[:, :, 
2:, :-2] + - phi_AB_vectorfield[:, :, :-2, 2:] + ) / (4.0 * self.spacing[0] * self.spacing[1]) + bending_energy = ( + torch.mean(dxdx**2) + + torch.mean(dydy**2) + + 2 * torch.mean(dxdy**2) + ) / 4.0 + elif len(self.identity_map.shape) == 5: + dxdx = ( + phi_AB_vectorfield[:, :, 2:] + - 2 * phi_AB_vectorfield[:, :, 1:-1] + + phi_AB_vectorfield[:, :, :-2] + ) / self.spacing[0] ** 2 + dydy = ( + phi_AB_vectorfield[:, :, :, 2:] + - 2 * phi_AB_vectorfield[:, :, :, 1:-1] + + phi_AB_vectorfield[:, :, :, :-2] + ) / self.spacing[1] ** 2 + dzdz = ( + phi_AB_vectorfield[:, :, :, :, 2:] + - 2 * phi_AB_vectorfield[:, :, :, :, 1:-1] + + phi_AB_vectorfield[:, :, :, :, :-2] + ) / self.spacing[2] ** 2 + dxdy = ( + phi_AB_vectorfield[:, :, 2:, 2:] + + phi_AB_vectorfield[:, :, :-2, :-2] + - phi_AB_vectorfield[:, :, 2:, :-2] + - phi_AB_vectorfield[:, :, :-2, 2:] + ) / (4.0 * self.spacing[0] * self.spacing[1]) + dydz = ( + phi_AB_vectorfield[:, :, :, 2:, 2:] + + phi_AB_vectorfield[:, :, :, :-2, :-2] + - phi_AB_vectorfield[:, :, :, 2:, :-2] + - phi_AB_vectorfield[:, :, :, :-2, 2:] + ) / (4.0 * self.spacing[1] * self.spacing[2]) + dxdz = ( + phi_AB_vectorfield[:, :, 2:, :, 2:] + + phi_AB_vectorfield[:, :, :-2, :, :-2] + - phi_AB_vectorfield[:, :, 2:, :, :-2] + - phi_AB_vectorfield[:, :, :-2, :, 2:] + ) / (4.0 * self.spacing[0] * self.spacing[2]) + + bending_energy = ( + (dxdx**2).mean() + + (dydy**2).mean() + + (dzdz**2).mean() + + 2.0 * (dxdy**2).mean() + + 2.0 * (dydz**2).mean() + + 2.0 * (dxdz**2).mean() + ) / 9.0 + + return bending_energy + + +class Diffusion(OneWayRegularizer): + def compute_regularizer(self, phi_AB_vectorfield): + phi_AB_vectorfield = self.identity_map - phi_AB_vectorfield + if len(self.identity_map.shape) == 3: + bending_energy = torch.mean( + (-phi_AB_vectorfield[:, :, 1:] + phi_AB_vectorfield[:, :, 1:-1]) ** 2 + ) + + elif len(self.identity_map.shape) == 4: + bending_energy = torch.mean( + (-phi_AB_vectorfield[:, :, 1:] + phi_AB_vectorfield[:, :, :-1]) ** 2 + ) + torch.mean( + (-phi_AB_vectorfield[:, :, :, 1:] + phi_AB_vectorfield[:, :, :, :-1]) + ** 2 + ) + elif len(self.identity_map.shape) == 5: + bending_energy = ( + torch.mean( + (-phi_AB_vectorfield[:, :, 1:] + phi_AB_vectorfield[:, :, :-1]) ** 2 + ) + + torch.mean( + ( + -phi_AB_vectorfield[:, :, :, 1:] + + phi_AB_vectorfield[:, :, :, :-1] + ) + ** 2 + ) + + torch.mean( + ( + -phi_AB_vectorfield[:, :, :, :, 1:] + + phi_AB_vectorfield[:, :, :, :, :-1] + ) + ** 2 + ) + ) + + return bending_energy * self.identity_map.shape[2] ** 2 + + +class VelocityFieldDiffusion(Diffusion): + def forward(self, image_A, image_B): + assert self.identity_map.shape[2:] == image_A.shape[2:] + assert self.identity_map.shape[2:] == image_B.shape[2:] + + # Tag used elsewhere for optimization. 
+ # Must be set at beginning of forward b/c not preserved by .cuda() etc + self.identity_map.isIdentity = True + + phi_AB_dict = self.regis_net(image_A, image_B) + phi_AB = phi_AB_dict["phi_AB"] + + phi_AB_vectorfield = phi_AB(self.identity_map) + + similarity_AB = self.compute_similarity(image_A, image_B, phi_AB_vectorfield) + + similarity_loss = 2 * similarity_AB["similarity_loss"] + + velocity_fields = phi_AB["velocity_fields"] + regularization_loss = 0 + for v in velocity_fields: + regularization_loss += self.compute_regularizer(self, v + self.identity_map) + + all_loss = self.lmbda * regularization_loss + similarity_loss + + negative_jacobian_voxels = flips(phi_AB_vectorfield) + + return { + "all_loss": all_loss, + "regularization_loss": regularization_loss, + "similarity_loss": similarity_loss, + "phi_AB": phi_AB, + "warped_image_A": similarity_AB["warped_image_A"], + "negative_jacobian_voxels": negative_jacobian_voxels, + } + + +class VelocityFieldBendingEnergy(BendingEnergy): + def forward(self, image_A, image_B): + assert self.identity_map.shape[2:] == image_A.shape[2:] + assert self.identity_map.shape[2:] == image_B.shape[2:] + + # Tag used elsewhere for optimization. + # Must be set at beginning of forward b/c not preserved by .cuda() etc + self.identity_map.isIdentity = True + + phi_AB_dict = self.regis_net(image_A, image_B) + phi_AB = phi_AB_dict["phi_AB"] + + phi_AB_vectorfield = phi_AB(self.identity_map) + + similarity_AB = self.compute_similarity(image_A, image_B, phi_AB_vectorfield) + + similarity_loss = 2 * similarity_AB["similarity_loss"] + + velocity_fields = phi_AB["velocity_fields"] + regularization_loss = 0 + for v in velocity_fields: + regularization_loss += self.compute_regularizer(self, v) + + all_loss = self.lmbda * regularization_loss + similarity_loss + + negative_jacobian_voxels = flips(phi_AB_vectorfield) + + return { + "all_loss": all_loss, + "regularization_loss": regularization_loss, + "similarity_loss": similarity_loss, + "phi_AB": phi_AB, + "warped_image_A": similarity_AB["warped_image_A"], + "negative_jacobian_voxels": negative_jacobian_voxels, + } diff --git a/ICON/icon_registration/network_wrappers.py b/ICON/icon_registration/network_wrappers.py new file mode 100644 index 00000000..14c193e7 --- /dev/null +++ b/ICON/icon_registration/network_wrappers.py @@ -0,0 +1,268 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from icon_registration.registration_module import RegistrationModule + + +class DisplacementField(RegistrationModule): + """ + Wrap an inner neural network 'net' that returns a tensor of displacements + [B x N x H x W (x D)], into a RegistrationModule that returns a function that + transforms a tensor of coordinates + """ + + def __init__(self, net): + super().__init__() + self.net = net + + def forward(self, image_A, image_B): + concatenated_images = torch.cat([image_A, image_B], axis=1) + tensor_of_displacements = self.net(concatenated_images) + displacement_field = self.as_function(tensor_of_displacements) + + def transform(coordinates): + return coordinates + displacement_field(coordinates) + + return {"phi_AB": transform} + + +class VelocityField(RegistrationModule): + def __init__(self, net): + super().__init__() + self.net = net + self.n_steps = 256 + + def forward(self, image_A, image_B): + concatenated_images = torch.cat([image_A, image_B], axis=1) + velocity_field = self.net(concatenated_images) + velocityfield_delta = velocity_field / self.n_steps + + for _ in range(8): + 
velocityfield_delta = velocityfield_delta + self.as_function( + velocityfield_delta + )(velocityfield_delta + self.identity_map) + + def transform(coordinate_tensor): + coordinate_tensor = coordinate_tensor + self.as_function( + velocityfield_delta + )(coordinate_tensor) + return coordinate_tensor + + return {"phi_AB": transform, "velocity_fields": [velocity_field]} + + +def multiply_matrix_vectorfield(matrix, vectorfield): + dimension = len(vectorfield.shape) - 2 + if dimension == 2: + batch_matrix_multiply = "ijkl,imj->imkl" + else: + batch_matrix_multiply = "ijkln,imj->imkln" + return torch.einsum(batch_matrix_multiply, vectorfield, matrix) + + +class Affine(RegistrationModule): + """ + wrap an inner neural network `net` that returns an N x N+1 matrix representing + an affine transform, into a RegistrationModule that returns a function that + transforms a tensor of coordinates. + """ + + def __init__(self, net): + super().__init__() + self.net = net + + def forward(self, image_A, image_B): + matrix_phi = self.net(image_A, image_B) + + def transform(tensor_of_coordinates): + shape = list(tensor_of_coordinates.shape) + shape[1] = 1 + coordinates_homogeneous = torch.cat( + [ + tensor_of_coordinates, + torch.ones(shape, device=tensor_of_coordinates.device), + ], + axis=1, + ) + return multiply_matrix_vectorfield(matrix_phi, coordinates_homogeneous)[ + :, :-1 + ] + + return {"phi_AB": transform} + + +class TwoStep(RegistrationModule): + """Combine two RegistrationModules. + + First netPhi is called on the input images, then image_A is warped with + the resulting field, and then netPsi is called on warped A and image_B + in order to find a residual warping. Finally, the composition of the two + transforms is returned. + """ + + def __init__(self, netPhi, netPsi): + super().__init__() + self.netPhi = netPhi + self.netPsi = netPsi + + def forward(self, image_A, image_B): + phi = self.netPhi(image_A, image_B) + psi = self.netPsi(self.as_function(image_A)(phi["phi_AB"](self.identity_map)), image_B) + result = { + "phi_AB": lambda tensor_of_coordinates: phi["phi_AB"]( + psi["phi_AB"](tensor_of_coordinates) + ) + } + + regularization_loss = 0 + if "regularization_loss" in phi: + regularization_loss += phi["regularization_loss"] + if "regularization_loss" in psi: + regularization_loss += psi["regularization_loss"] + if "regularization_loss" in phi or "regularization_loss" in psi: + result["regularization_loss"] = regularization_loss + + velocity_fields = [] + if "velocity_fields" in phi: + velocity_fields += phi["regularization_loss"] + if "velocity_fields" in psi: + velocity_fields += psi["regularization_loss"] + if "velocity_fields" in phi or "regularization_loss" in psi: + result["velocity_fields"] = regularization_loss + return result + + +class Downsample(RegistrationModule): + """ + Perform registration using the wrapped RegistrationModule `net` + at half input resolution. + """ + + def __init__(self, net, dimension): + super().__init__() + self.net = net + if dimension == 2: + self.avg_pool = F.avg_pool2d + self.interpolate_mode = "bilinear" + else: + self.avg_pool = F.avg_pool3d + self.interpolate_mode = "trilinear" + self.dimension = dimension + # This member variable is read by assign_identity_map when + # walking the network tree and assigning identity_maps + # to know that all children of this module operate at a lower + # resolution. 
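+        # e.g. with downscale_factor 2, an identity map of spatial shape (80, 192, 192)
+        # becomes (40, 96, 96) for the wrapped `net`.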
+ self.downscale_factor = 2 + + def forward(self, image_A, image_B): + image_A = self.avg_pool(image_A, 2, ceil_mode=True) + image_B = self.avg_pool(image_B, 2, ceil_mode=True) + result = self.net(image_A, image_B) + + # MONAI's ddf coordinate convention depends on resolution: + + for key in ["phi_AB", "phi_BA"]: + if key in lowres_result: + highres_phi = lambda coords: 2 * result[key](coords / 2) + result[key] = highres_phi + + return result + + +class InverseConsistentVelocityField(RegistrationModule): + def __init__(self, net): + super().__init__() + self.net = net + self.n_steps = 7 + + def forward(self, image_A, image_B): + concatenated_images_AB = torch.cat([image_A, image_B], axis=1) + concatenated_images_BA = torch.cat([image_B, image_A], axis=1) + velocity_field = self.net(concatenated_images_AB) - self.net( + concatenated_images_BA + ) + velocityfield_delta_ab = velocity_field / 2**self.n_steps + velocityfield_delta_ba = -velocityfield_delta_ab + + for _ in range(self.n_steps): + velocityfield_delta_ab = velocityfield_delta_ab + self.as_function( + velocityfield_delta_ab + )(velocityfield_delta_ab + self.identity_map) + + def transform_AB(coordinate_tensor): + coordinate_tensor = coordinate_tensor + self.as_function( + velocityfield_delta_ab + )(coordinate_tensor) + return coordinate_tensor + + for _ in range(self.n_steps): + velocityfield_delta_ba = velocityfield_delta_ba + self.as_function( + velocityfield_delta_ba + )(velocityfield_delta_ba + self.identity_map) + + def transform_BA(coordinate_tensor): + coordinate_tensor = coordinate_tensor + self.as_function( + velocityfield_delta_ba + )(coordinate_tensor) + return coordinate_tensor + + return { + "phi_AB": transform_AB, + "phi_BA": transform_BA, + "velocity_fields": [velocity_field], + } + + +class InverseConsistentAffine(RegistrationModule): + """ + wrap an inner neural network `net` that returns an Batch x N*N+1 tensor representing + an affine transform, into a RegistrationModule that returns a function that + transforms a tensor of coordinates. 
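+    The predicted matrix is antisymmetrized between the two input orderings (net(A,B) - net(B,A))
+    before being exponentiated, so swapping the inputs negates the matrix by construction.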
+ """ + + def __init__(self, net): + super().__init__() + self.net = net + + def forward(self, image_A, image_B): + concatenated_images_AB = torch.cat([image_A, image_B], axis=1) + concatenated_images_BA = torch.cat([image_B, image_A], axis=1) + matrix_phi = self.net(concatenated_images_AB) - self.net(concatenated_images_BA) + + matrix_phi = matrix_phi.reshape( + image_A.shape[0], len(image_A.shape), len(image_A.shape) + 1 + ) + + matrix_phi_AB = torch.linalg.matrix_exp(matrix_phi) + matrix_phi_BA = torch.linalg.matrix_exp(-matrix_phi) + + def transform_AB(tensor_of_coordinates): + shape = list(tensor_of_coordinates.shape) + shape[1] = 1 + coordinates_homogeneous = torch.cat( + [ + tensor_of_coordinates, + self.torch.ones(shape, device=tensor_of_coordinates.device), + ], + axis=1, + ) + return multiply_matrix_vectorfield(matrix_phi, coordinates_homogeneous)[ + :, :-1 + ] + + def transform_BA(tensor_of_coordinates): + shape = list(tensor_of_coordinates.shape) + shape[1] = 1 + coordinates_homogeneous = torch.cat( + [ + tensor_of_coordinates, + torch.ones(shape, device=tensor_of_coordinates.device), + ], + axis=1, + ) + return multiply_matrix_vectorfield(matrix_phi_BA, coordinates_homogeneous)[ + :, :-1 + ] + + return {"phi_AB": transform_AB, "phi_BA": transform_BA} diff --git a/ICON/icon_registration/networks.py b/ICON/icon_registration/networks.py new file mode 100644 index 00000000..e1008071 --- /dev/null +++ b/ICON/icon_registration/networks.py @@ -0,0 +1,380 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from icon_registration import config + + +class ConvNet(nn.Module): + def __init__(self, dimension=2, output_dim=100): + super().__init__() + self.dimension = dimension + + if dimension == 2: + self.Conv = nn.Conv2d + self.avg_pool = F.avg_pool2d + else: + self.Conv = nn.Conv3d + self.avg_pool = F.avg_pool3d + + self.features = [2, 16, 32, 64, 128, 128, 256] + self.convs = nn.ModuleList([]) + for depth in range(len(self.features) - 1): + self.convs.append( + self.Conv( + self.features[depth], + self.features[depth + 1], + kernel_size=3, + padding=1, + ) + ) + self.dense2 = nn.Linear(256, 300) + self.dense3 = nn.Linear(300, output_dim) + + def forward(self, x): + for depth in range(len(self.features) - 1): + x = F.relu(x) + x = self.convs[depth](x) + x = self.avg_pool(x, 2, ceil_mode=True) + x = self.avg_pool(x, x.shape[2:], ceil_mode=True) + x = torch.reshape(x, (-1, 256)) + x = F.relu(self.dense2(x)) + x = self.dense3(x) + return x + +def pad_or_crop(x, shape, dimension): + y = x[:, : shape[1]] + if x.size()[1] < shape[1]: + if dimension == 3: + y = F.pad(y, (0, 0, 0, 0, 0, 0, shape[1] - x.size()[1], 0)) + else: + y = F.pad(y, (0, 0, 0, 0, shape[1] - x.size()[1], 0)) + assert y.size()[1] == shape[1] + + return y + + +class UNet2(nn.Module): + def __init__(self, num_layers, channels, dimension): + super().__init__() + self.dimension = dimension + if dimension == 2: + self.BatchNorm = nn.BatchNorm2d + self.Conv = nn.Conv2d + self.ConvTranspose = nn.ConvTranspose2d + self.avg_pool = F.avg_pool2d + self.interpolate_mode = "bilinear" + else: + self.BatchNorm = nn.BatchNorm3d + self.Conv = nn.Conv3d + self.ConvTranspose = nn.ConvTranspose3d + self.avg_pool = F.avg_pool3d + self.interpolate_mode = "trilinear" + self.num_layers = num_layers + down_channels = np.array(channels[0]) + up_channels_out = np.array(channels[1]) + up_channels_in = down_channels[1:] + np.concatenate([up_channels_out[1:], [0]]) + self.downConvs = nn.ModuleList([]) + 
self.upConvs = nn.ModuleList([]) + self.batchNorms = nn.ModuleList( + [ + self.BatchNorm(num_features=up_channels_out[_]) + for _ in range(self.num_layers) + ] + ) + for depth in range(self.num_layers): + self.downConvs.append( + self.Conv( + down_channels[depth], + down_channels[depth + 1], + kernel_size=3, + padding=1, + stride=2, + ) + ) + self.upConvs.append( + self.ConvTranspose( + up_channels_in[depth], + up_channels_out[depth], + kernel_size=4, + padding=1, + stride=2, + ) + ) + self.lastConv = self.Conv( + down_channels[0] + up_channels_out[0], dimension, kernel_size=3, padding=1 + ) + torch.nn.init.zeros_(self.lastConv.weight) + torch.nn.init.zeros_(self.lastConv.bias) + + def forward(self, x): + skips = [] + for depth in range(self.num_layers): + skips.append(x) + y = self.downConvs[depth](F.leaky_relu(x)) + x = y + pad_or_crop( + self.avg_pool(x, 2, ceil_mode=True), y.size(), self.dimension + ) + + for depth in reversed(range(self.num_layers)): + y = self.upConvs[depth](F.leaky_relu(x)) + x = y + F.interpolate( + pad_or_crop(x, y.size(), self.dimension), + scale_factor=2, + mode=self.interpolate_mode, + align_corners=False, + ) + x = self.batchNorms[depth](x) + if self.dimension == 2: + x = x[:, :, : skips[depth].size()[2], : skips[depth].size()[3]] + else: + x = x[ + :, + :, + : skips[depth].size()[2], + : skips[depth].size()[3], + : skips[depth].size()[4], + ] + x = torch.cat([x, skips[depth]], 1) + x = self.lastConv(x) + return x * 10 + + +class UNetDenseMiddle(nn.Module): + def __init__(self, num_layers, channels, dimension): + super().__init__() + self.dimension = dimension + if dimension == 2: + self.BatchNorm = nn.BatchNorm2d + self.Conv = nn.Conv2d + self.ConvTranspose = nn.ConvTranspose2d + self.avg_pool = F.avg_pool2d + self.interpolate_mode = "bilinear" + else: + self.BatchNorm = nn.BatchNorm3d + self.Conv = nn.Conv3d + self.ConvTranspose = nn.ConvTranspose3d + self.avg_pool = F.avg_pool3d + self.interpolate_mode = "trilinear" + self.num_layers = num_layers + down_channels = np.array(channels[0]) + up_channels_out = np.array(channels[1]) + up_channels_in = down_channels[1:] + np.concatenate([up_channels_out[1:], [0]]) + self.downConvs = nn.ModuleList([]) + self.upConvs = nn.ModuleList([]) + self.batchNorms = nn.ModuleList( + [ + self.BatchNorm(num_features=up_channels_out[_]) + for _ in range(self.num_layers) + ] + ) + for depth in range(self.num_layers): + self.downConvs.append( + self.Conv( + down_channels[depth], + down_channels[depth + 1], + kernel_size=3, + padding=1, + stride=2, + ) + ) + self.upConvs.append( + self.ConvTranspose( + up_channels_in[depth], + up_channels_out[depth], + kernel_size=4, + padding=1, + stride=2, + ) + ) + self.lastConv = self.Conv(18, dimension, kernel_size=3, padding=1) + torch.nn.init.zeros_(self.lastConv.weight) + + self.middle_dense = nn.ModuleList( + [ + torch.nn.Linear(512 * 2 * 3 * 3, 128 * 2 * 3 * 3), + torch.nn.Linear(128 * 2 * 3 * 3, 512 * 2 * 3 * 3), + ] + ) + + def forward(self, x): + skips = [] + for depth in range(self.num_layers): + skips.append(x) + y = self.downConvs[depth](F.leaky_relu(x)) + x = y + pad_or_crop( + self.avg_pool(x, 2, ceil_mode=True), y.size(), self.dimension + ) + y = F.layer_norm + + x = torch.reshape(x, (-1, 512 * 2 * 3 * 3)) + x = self.middle_dense[1](F.leaky_relu(self.middle_dense[0](x))) + x = torch.reshape(x, (-1, 512, 2, 3, 3)) + for depth in reversed(range(self.num_layers)): + y = self.upConvs[depth](F.leaky_relu(x)) + x = y + F.interpolate( + pad_or_crop(x, y.size(), self.dimension), + 
scale_factor=2, + mode=self.interpolate_mode, + align_corners=False, + ) + x = self.batchNorms[depth](x) + + x = x[:, :, : skips[depth].size()[2], : skips[depth].size()[3]] + x = torch.cat([x, skips[depth]], 1) + x = self.lastConv(x) + return x / 10 + + + + +def tallUNet2(dimension=2, input_channels=1): + return UNet2( + 5, + [[input_channels * 2, 16, 32, 64, 256, 512], [16, 32, 64, 128, 256]], + dimension, + ) + + +class FCNet1D(nn.Module): + def __init__(self, size=28): + super().__init__() + self.size = size + self.dense1 = nn.Linear(size * 2, 8000) + self.dense2 = nn.Linear(8000, 3000) + self.dense3 = nn.Linear(3000, size) + torch.nn.init.zeros_(self.dense3.weight) + + def forward(self, x): + x = torch.reshape(x, (-1, 2 * self.size)) + x = F.relu(self.dense1(x)) + x = F.relu(self.dense2(x)) + x = self.dense3(x) + x = torch.reshape(x, (-1, 1, self.size)) + return x + + +class FCNet(nn.Module): + def __init__(self, size=28): + super().__init__() + self.size = size + self.dense1 = nn.Linear(size * size * 2, 8000) + self.dense2 = nn.Linear(8000, 3000) + self.dense3 = nn.Linear(3000, size * size * 2) + torch.nn.init.zeros_(self.dense3.weight) + + def forward(self, x): + x = torch.reshape(x, (-1, 2 * self.size * self.size)) + x = F.relu(self.dense1(x)) + x = F.relu(self.dense2(x)) + x = self.dense3(x) + x = torch.reshape(x, (-1, 2, self.size, self.size)) + return x + +class FCNet3D(nn.Module): + def __init__(self, shape, bottleneck=128): + super().__init__() + self.shape = shape.copy() + self.shape[1] = 3 + self.bottleneck = bottleneck + self.dense1 = nn.Linear(2 * np.product(self.shape[2:]), self.bottleneck) + self.dense2 = nn.Linear(self.bottleneck, 8000) + self.dense3 = nn.Linear(8000, self.bottleneck) + self.dense4 = nn.Linear(self.bottleneck, np.product(self.shape[1:])) + torch.nn.init.zeros_(self.dense4.weight) + torch.nn.init.zeros_(self.dense4.bias) + + def forward(self, x): + x = torch.reshape(x, (-1, 2 * np.product(self.shape[2:]))) + x = F.relu(self.dense1(x)) + x = F.relu(self.dense2(x)) + x = F.relu(self.dense3(x)) + x = self.dense4(x) + x = torch.reshape(x, tuple(self.shape)) + return x + + +class ConvolutionalMatrixNet(nn.Module): + def __init__(self, dimension=2): + super().__init__() + self.dimension = dimension + + if dimension == 2: + self.Conv = nn.Conv2d + self.avg_pool = F.avg_pool2d + else: + self.Conv = nn.Conv3d + self.avg_pool = F.avg_pool3d + + self.features = [2, 16, 32, 64, 128, 256, 512] + self.convs = nn.ModuleList([]) + for depth in range(len(self.features) - 1): + self.convs.append( + self.Conv( + self.features[depth], + self.features[depth + 1], + kernel_size=3, + padding=1, + ) + ) + self.dense2 = nn.Linear(512, 300) + self.dense3 = nn.Linear(300, 6 if self.dimension == 2 else 12) + torch.nn.init.zeros_(self.dense3.weight) + torch.nn.init.zeros_(self.dense3.bias) + + def forward(self, x): + for depth in range(len(self.features) - 1): + x = F.relu(x) + x = self.convs[depth](x) + x = self.avg_pool(x, 2, ceil_mode=True) + x = self.avg_pool(x, x.shape[2:], ceil_mode=True) + x = torch.reshape(x, (-1, 512)) + x = F.relu(self.dense2(x)) + x = self.dense3(x) + if self.dimension == 3: + x = torch.reshape(x, (-1, 3, 4)) + x = torch.cat( + [ + x, + torch.Tensor([[[0, 0, 0, 1]]]) + .to(x.device) + .expand(x.shape[0], -1, -1), + ], + 1, + ) + x = x + torch.Tensor( + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]] + ).to(x.device) + x = torch.matmul( + torch.Tensor( + [[1, 0, 0, 0.5], [0, 1, 0, 0.5], [0, 0, 1, 0.5], [0, 0, 0, 1]] + ).to(x.device), + x, + ) + x = 
torch.matmul( + x, + torch.Tensor( + [[1, 0, 0, -0.5], [0, 1, 0, -0.5], [0, 0, 1, -0.5], [0, 0, 0, 1]] + ).to(x.device), + ) + elif self.dimension == 2: + x = torch.reshape(x, (-1, 2, 3)) + x = torch.cat( + [ + x, + torch.Tensor([[[0, 0, 1]]]).to(x.device).expand(x.shape[0], -1, -1), + ], + 1, + ) + x = x + torch.Tensor([[1, 0, 0], [0, 1, 0], [0, 0, 0]]).to(x.device) + x = torch.matmul( + torch.Tensor([[1, 0, 0.5], [0, 1, 0.5], [0, 0, 1]]).to(x.device), x + ) + x = torch.matmul( + x, + torch.Tensor([[1, 0, -0.5], [0, 1, -0.5], [0, 0, 1]]).to(x.device), + ) + else: + raise ArgumentError() + return x diff --git a/ICON/icon_registration/pretrained_models/HCP_brain.py b/ICON/icon_registration/pretrained_models/HCP_brain.py new file mode 100644 index 00000000..885e77c8 --- /dev/null +++ b/ICON/icon_registration/pretrained_models/HCP_brain.py @@ -0,0 +1,19 @@ +import itk +from .lung_ct import init_network + + +def brain_network_preprocess(image: "itk.Image") -> "itk.Image": + if type(image) == itk.Image[itk.SS, 3]: + cast_filter = itk.CastImageFilter[ + itk.Image[itk.SS, 3], itk.Image[itk.F, 3] + ].New() + cast_filter.SetInput(image) + cast_filter.Update() + image = cast_filter.GetOutput() + _, max_ = itk.image_intensity_min_max(image) + image = itk.shift_scale_image_filter(image, shift=0.0, scale=0.9 / max_) + return image + + +def brain_registration_model(pretrained=True): + return init_network("brain", pretrained=pretrained) diff --git a/ICON/icon_registration/pretrained_models/OAI_knees.py b/ICON/icon_registration/pretrained_models/OAI_knees.py new file mode 100644 index 00000000..94fb224f --- /dev/null +++ b/ICON/icon_registration/pretrained_models/OAI_knees.py @@ -0,0 +1,67 @@ +from os.path import exists + +import torch + +import icon_registration +from icon_registration import config + +from .. import networks +from .lung_ct import init_network +from ..similarity import SSD + + +def OAI_knees_registration_model(pretrained=True): + # The definition of our final 4 step registration network. 
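+    # Two displacement-field U-Nets (phi, psi) are composed at low resolution inside a downsampling
+    # wrapper, then refined by two more displacement-field U-Nets at the working resolution;
+    # the full stack is trained with the ICON inverse-consistency loss (SSD similarity, lmbda=3600).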
+ + phi = icon_registration.DisplacementField( + networks.tallUNet(unet=networks.UNet2ChunkyMiddle, dimension=3) + ) + psi = icon_registration.DisplacementField(networks.tallUNet2(dimension=3)) + + pretrained_lowres_net = icon_registration.TwoStepRegistration(phi, psi) + + hires_net = icon_registration.TwoStepRegistration( + icon_registration.DownsampleRegistration(pretrained_lowres_net, dimension=3), + icon_registration.DisplacementField(networks.tallUNet2(dimension=3)), + ) + + fourth_net = icon_registration.ICON( + icon_registration.TwoStepRegistration( + hires_net, + icon_registration.DisplacementField(networks.tallUNet2(dimension=3)), + ), + SSD(), + 3600, + ) + + BATCH_SIZE = 3 + SCALE = 2 # 1 IS QUARTER RES, 2 IS HALF RES, 4 IS FULL RES + input_shape = [BATCH_SIZE, 1, 40 * SCALE, 96 * SCALE, 96 * SCALE] + + if pretrained: + weights_location = "network_weights/pretrained_OAI_model_compact" + if not exists(weights_location): + print("Downloading pretrained model (500 Mb)") + import urllib.request + import os + + os.makedirs("network_weights", exist_ok=True) + + urllib.request.urlretrieve( + "https://github.com/uncbiag/ICON/releases/download/pretrained_oai_model/OAI_knees_ICON_model.pth", + weights_location, + ) + + trained_weights = torch.load(weights_location, map_location=torch.device("cpu")) + fourth_net.load_state_dict(trained_weights) + + fourth_net.assign_identity_map(input_shape) + + net = fourth_net + net.to(config.device) + net.eval() + return net + + +def OAI_knees_gradICON_model(pretrained=True): + return init_network("knee", pretrained) diff --git a/ICON/icon_registration/pretrained_models/__init__.py b/ICON/icon_registration/pretrained_models/__init__.py new file mode 100644 index 00000000..48d9313b --- /dev/null +++ b/ICON/icon_registration/pretrained_models/__init__.py @@ -0,0 +1,3 @@ +from .OAI_knees import OAI_knees_gradICON_model, OAI_knees_registration_model +from .lung_ct import LungCT_registration_model, lung_network_preprocess +from .HCP_brain import brain_registration_model, brain_network_preprocess diff --git a/ICON/icon_registration/pretrained_models/lung_ct.py b/ICON/icon_registration/pretrained_models/lung_ct.py new file mode 100644 index 00000000..4abc9930 --- /dev/null +++ b/ICON/icon_registration/pretrained_models/lung_ct.py @@ -0,0 +1,106 @@ +import itk +import torch + +import icon_registration.config as config + +from .. import losses, network_wrappers, networks + + +def make_network(): + dimension = 3 + inner_net = network_wrappers.DisplacementField( + networks.tallUNet2(dimension=dimension) + ) + + for _ in range(2): + inner_net = network_wrappers.TwoStepRegistration( + network_wrappers.DownsampleRegistration(inner_net, dimension=dimension), + network_wrappers.DisplacementField(networks.tallUNet2(dimension=dimension)), + ) + inner_net = network_wrappers.TwoStepRegistration( + inner_net, + network_wrappers.DisplacementField(networks.tallUNet2(dimension=dimension)), + ) + + net = losses.GradICON(inner_net, similarity=losses.LNCC(sigma=5), lmbda=1.5) + + return net + + +def init_network(task, pretrained=True): + if task == "lung": + input_shape = [1, 1, 175, 175, 175] + elif task == "knee": + input_shape = [1, 1, 80, 192, 192] + elif task == "brain": + input_shape = [1, 1, 130, 155, 130] + else: + print(f"Task {task} is not defined. 
Fall back to the lung model.") + task = "lung" + input_shape = [1, 1, 175, 175, 175] + + net = make_network() + net.assign_identity_map(input_shape) + + if pretrained: + from os.path import exists + + weights_location = f"network_weights/{task}_model" + + if not exists(f"{weights_location}/{task}_model_weights.trch"): + print("Downloading pretrained model") + import urllib.request + import os + + download_path = "https://github.com/uncbiag/ICON/releases/download" + download_path = f"{download_path}/pretrained_models_v1.0.0" + + os.makedirs(weights_location, exist_ok=True) + urllib.request.urlretrieve( + f"{download_path}/{task}_model_weights_step_2.trch", + f"{weights_location}/{task}_model_weights.trch", + ) + + trained_weights = torch.load( + f"{weights_location}/{task}_model_weights.trch", + map_location=torch.device("cpu"), + ) + net.regis_net.load_state_dict(trained_weights, strict=False) + net.assign_identity_map(input_shape) + + net.to(config.device) + net.eval() + return net + + +def lung_network_preprocess( + image: "itk.Image", segmentation: "itk.Image" +) -> "itk.Image": + image = itk.clamp_image_filter(image, Bounds=(-1000, 0)) + cast_filter = itk.CastImageFilter[type(image), itk.Image.F3].New() + cast_filter.SetInput(image) + cast_filter.Update() + image = cast_filter.GetOutput() + + segmentation_cast_filter = itk.CastImageFilter[ + type(segmentation), itk.Image.F3 + ].New() + segmentation_cast_filter.SetInput(segmentation) + segmentation_cast_filter.Update() + segmentation = segmentation_cast_filter.GetOutput() + + image = itk.shift_scale_image_filter(image, shift=1000, scale=1 / 1000) + + mask_filter = itk.MultiplyImageFilter[ + itk.Image.F3, itk.Image.F3, itk.Image.F3 + ].New() + + mask_filter.SetInput1(image) + mask_filter.SetInput2(segmentation) + mask_filter.Update() + + return mask_filter.GetOutput() + + +def LungCT_registration_model(pretrained=True): + return init_network("lung", pretrained=pretrained) diff --git a/ICON/icon_registration/registration_module.py b/ICON/icon_registration/registration_module.py new file mode 100644 index 00000000..68835798 --- /dev/null +++ b/ICON/icon_registration/registration_module.py @@ -0,0 +1,156 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from monai.networks.blocks import Warp +from monai.networks.utils import meshgrid_ij + +class CoordinateWarp(Warp): + def forward(self, image: torch.tensor, ddf: torch.tensor): + """ + args: + image: tensor in shape (batch, num_channels, h, w[, d]) + ddf: tensor in the same spatial size as image, in shape (batch, ``spatial_dims``, h, w[, d]) + + returns: + warped_image in the same shape as image (batch, num_channels, h, w[, d]) + """ + spatial_dims = len(image.shape) - 2 + if spatial_dims not in (2, 3): + raise notimplementederror(f"got unsupported spatial_dims={spatial_dims}, currently support 2 or 3.") + ddf_shape = (image.shape[0], spatial_dims) + tuple(image.shape[2:]) + grid = ddf # assume + grid = grid.permute([0] + list(range(2, 2 + spatial_dims)) + [1]) # (batch, ..., spatial_dims) + + for i, dim in enumerate(grid.shape[1:-1]): + grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1 + index_ordering: list[int] = list(range(spatial_dims - 1, -1, -1)) + grid = grid[..., index_ordering] # z, y, x -> x, y, z + return F.grid_sample( + image, grid, mode=self._interp_mode, padding_mode=f"{self._padding_mode}", align_corners=True + ) + + + + + +class RegistrationModule(nn.Module): + r"""Base class for icon modules that perform registration. 
+ + A subclass of RegistrationModule should have a forward method that + takes as input two images image_A and image_B, and returns a python function + phi_AB that transforms a tensor of coordinates. + + RegistrationModule provides a method as_function that turns a tensor + representing an image into a python function mapping a tensor of coordinates + into a tensor of intensities :math:`\mathbb{R}^N \rightarrow \mathbb{R}` . + Mathematically, this is what an image is anyway. + + After this class is constructed, but before it is used, you _must_ call + assign_identity_map on it or on one of its parents to define the coordinate + system associated with input images. + + The contract that a successful registration fulfils is: + for a tensor of coordinates X, self.as_function(image_A)(phi_AB(X)) ~= self.as_function(image_B)(X) + + ie + + .. math:: + I^A \circ \Phi^{AB} \simeq I^B + + In particular, self.as_function(image_A)(phi_AB(self.identity_map)) ~= image_B + """ + + def __init__(self): + super().__init__() + self.downscale_factor = 1 + self.warp = CoordinateWarp() + self.identity_map = None + + def _make_identity_map(self, shape): + mesh_points = [torch.arange(0, dim) for dim in shape[2:]] + grid = torch.stack(meshgrid_ij(*mesh_points), dim=0) # (spatial_dims, ...) + grid = torch.stack([grid], dim=0).float() # (batch, spatial_dims, ...) + return grid + + def as_function(self, image): + """image is a (potentially vector valued) tensor with shape self.input_shape. + Returns a python function that maps a tensor of coordinates [batch x N_dimensions x ...] + into a tensor of the intensity of `image` at `coordinates`. + + This allows translating the standard notation of registration papers more literally into code. + + I \\circ \\Phi , the standard mathematical notation for a warped image, has the type + "function from coordinates to intensities" and can be translated to the python code + + warped_image = lambda coords: self.as_function(I)(phi(coords)) + + Often, this should actually be left as a function. If a tensor is needed, conversion is: + + warped_image_tensor = warped_image(self.identity_map) + """ + + def partially_applied_warp(coordinates): + coordinates_shape = list(coordinates.shape) + coordinates_shape[0] = image.shape[0] + coordinates = torch.broadcast_to(coordinates, coordinates_shape) + return self.warp(image, coordinates.clone()) + + return partially_applied_warp + + def assign_identity_map(self, input_shape, parents_identity_map=None): + self.input_shape = input_shape + grid = self._make_identity_map(input_shape) + del self.identity_map + self.register_buffer("identity_map", grid, persistent=False) + + if self.downscale_factor != 1: + child_shape = np.concatenate( + [ + self.input_shape[:2], + np.ceil(self.input_shape[2:] / self.downscale_factor).astype(int), + ] + ) + else: + child_shape = self.input_shape + for child in self.children(): + if isinstance(child, RegistrationModule): + child.assign_identity_map( + child_shape, + # None if self.downscale_factor != 1 else self.identity_map, + ) + + def make_ddf_from_icon_transform(self, transform): + """Compute A deformation field compatible with monai's Warp + using an ICON transform. The assosciated ICON identity_map is also required + """ + return transform(self.identity_map) - self.identity_map + + def make_ddf_using_icon_module(self, image_A, image_B): + """Compute a deformation field compatible with monai's Warp + using an ICON RegistrationModule. 
If the RegistrationModule returns a transform, this function + returns the monai version of that transform. If the RegistrationModule returns a loss, + this function returns a monai version of the internal transform as well as the loss. + """ + + res = self(image_A, image_B) + field = self.make_ddf_from_icon_transform(res["phi_AB"]) + return field, res + + def forward(image_A, image_B): + """Register a pair of images: + return a python function phi_AB that warps a tensor of coordinates such that + + .. code-block:: python + + self.as_function(image_A)(phi_AB(self.identity_map)) ~= image_B + + .. math:: + I^A \circ \Phi^{AB} \simeq I^B + + :param image_A: the moving image + :param image_B: the fixed image + :return: :math:`\Phi^{AB}` + """ + raise NotImplementedError() diff --git a/ICON/icon_registration/similarity.py b/ICON/icon_registration/similarity.py new file mode 100644 index 00000000..e49c9c12 --- /dev/null +++ b/ICON/icon_registration/similarity.py @@ -0,0 +1,281 @@ +import torch + +def normalize(image): + dimension = len(image.shape) - 2 + if dimension == 2: + dim_reduce = [2, 3] + elif dimension == 3: + dim_reduce = [2, 3, 4] + image_centered = image - torch.mean(image, dim_reduce, keepdim=True) + stddev = torch.sqrt(torch.mean(image_centered**2, dim_reduce, keepdim=True)) + return image_centered / stddev + + +class SimilarityBase: + def __init__(self, isInterpolated=False): + self.isInterpolated = isInterpolated + + +class NCC(SimilarityBase): + def __init__(self): + super().__init__(isInterpolated=False) + + def __call__(self, image_A, image_B): + assert ( + image_A.shape == image_B.shape + ), "The shape of image_A and image_B sould be the same." + A = normalize(image_A) + B = normalize(image_B) + res = torch.mean(A * B) + return 1 - res + + +# torch removed this function from torchvision.functional_tensor, so we are vendoring it. 
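+# The 1-D kernel is applied once per spatial axis (with grouped convolutions, so
+# channels are blurred independently) by gaussian_blur() below; separable
+# filtering like this is equivalent to a full N-D Gaussian blur but cheaper.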
+def _get_gaussian_kernel1d(kernel_size, sigma): + ksize_half = (kernel_size - 1) * 0.5 + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + kernel1d = pdf / pdf.sum() + return kernel1d + + +def gaussian_blur(tensor, kernel_size, sigma, padding="same"): + kernel1d = _get_gaussian_kernel1d(kernel_size=kernel_size, sigma=sigma).to( + tensor.device, dtype=tensor.dtype + ) + out = tensor + group = tensor.shape[1] + + if len(tensor.shape) - 2 == 1: + out = torch.conv1d( + out, + kernel1d[None, None, :].expand(group, -1, -1), + padding="same", + groups=group, + ) + elif len(tensor.shape) - 2 == 2: + out = torch.conv2d( + out, + kernel1d[None, None, :, None].expand(group, -1, -1, -1), + padding="same", + groups=group, + ) + out = torch.conv2d( + out, + kernel1d[None, None, None, :].expand(group, -1, -1, -1), + padding="same", + groups=group, + ) + elif len(tensor.shape) - 2 == 3: + out = torch.conv3d( + out, + kernel1d[None, None, :, None, None].expand(group, -1, -1, -1, -1), + padding="same", + groups=group, + ) + out = torch.conv3d( + out, + kernel1d[None, None, None, :, None].expand(group, -1, -1, -1, -1), + padding="same", + groups=group, + ) + out = torch.conv3d( + out, + kernel1d[None, None, None, None, :].expand(group, -1, -1, -1, -1), + padding="same", + groups=group, + ) + + return out + + +class LNCC(SimilarityBase): + def __init__(self, sigma): + super().__init__(isInterpolated=False) + self.sigma = sigma + + def blur(self, tensor): + return gaussian_blur(tensor, self.sigma * 4 + 1, self.sigma) + + def __call__(self, image_A, image_B): + I = image_A + J = image_B + assert I.shape == J.shape, "The shape of image I and J sould be the same." + + return torch.mean( + 1 + - (self.blur(I * J) - (self.blur(I) * self.blur(J))) + / torch.sqrt( + (self.blur(I * I) - self.blur(I) ** 2 + 0.00001) + * (self.blur(J * J) - self.blur(J) ** 2 + 0.00001) + ) + ) + + +class LNCCOnlyInterpolated(SimilarityBase): + def __init__(self, sigma): + super().__init__(isInterpolated=True) + self.sigma = sigma + + def blur(self, tensor): + return gaussian_blur(tensor, self.sigma * 4 + 1, self.sigma) + + def __call__(self, image_A, image_B): + I = image_A[:, :-1] + J = image_B + + assert I.shape == J.shape, "The shape of image I and J sould be the same." + lncc_everywhere = 1 - ( + self.blur(I * J) - (self.blur(I) * self.blur(J)) + ) / torch.sqrt( + (self.blur(I * I) - self.blur(I) ** 2 + 0.00001) + * (self.blur(J * J) - self.blur(J) ** 2 + 0.00001) + ) + + with torch.no_grad(): + A_inbounds = image_A[:, -1:] + + inbounds_mask = self.blur(A_inbounds) > 0.999 + + if len(image_A.shape) - 2 == 3: + dimensions_to_sum_over = [2, 3, 4] + elif len(image_A.shape) - 2 == 2: + dimensions_to_sum_over = [2, 3] + elif len(image_A.shape) - 2 == 1: + dimensions_to_sum_over = [2] + + lncc_loss = torch.sum( + inbounds_mask * lncc_everywhere, dimensions_to_sum_over + ) / torch.sum(inbounds_mask, dimensions_to_sum_over) + + return torch.mean(lncc_loss) + + +class BlurredSSD(SimilarityBase): + def __init__(self, sigma): + super().__init__(isInterpolated=False) + self.sigma = sigma + + def blur(self, tensor): + return gaussian_blur(tensor, self.sigma * 4 + 1, self.sigma) + + def __call__(self, image_A, image_B): + assert ( + image_A.shape == image_B.shape + ), "The shape of image_A and image_B sould be the same." 
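+        # SSD of the Gaussian-smoothed images; the blur makes the loss less
+        # sensitive to noise and to very small residual misalignments.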
+ return torch.mean((self.blur(image_A) - self.blur(image_B)) ** 2) + + +class AdaptiveNCC(SimilarityBase): + def __init__(self, level=4, threshold=0.1, gamma=1.5, sigma=2): + super().__init__(isInterpolated=False) + self.level = level + self.threshold = threshold + self.gamma = gamma + self.sigma = sigma + + def blur(self, tensor): + return gaussian_blur(tensor, self.sigma * 2 + 1, self.sigma) + + def __call__(self, image_A, image_B): + assert ( + image_A.shape == image_B.shape + ), "The shape of image_A and image_B sould be the same." + + def _nccBeforeMean(image_A, image_B): + A = normalize(image_A) + B = normalize(image_B) + res = torch.mean(A * B, dim=(1, 2, 3, 4)) + return 1 - res + + sims = [_nccBeforeMean(image_A, image_B)] + for i in range(self.level): + if i == 0: + sims.append(_nccBeforeMean(self.blur(image_A), self.blur(image_B))) + else: + sims.append( + _nccBeforeMean( + self.blur(F.avg_pool3d(image_A, 2**i)), + self.blur(F.avg_pool3d(image_B, 2**i)), + ) + ) + + sim_loss = sims[0] + 0 + lamb_ = 1.0 + for i in range(1, len(sims)): + lamb = torch.clamp( + sims[i].detach() / (self.threshold / (self.gamma ** (len(sims) - i))), + 0, + 1, + ) + sim_loss = lamb * sims[i] + (1 - lamb) * sim_loss + lamb_ *= 1 - lamb + + return torch.mean(sim_loss) + + +class SSD(SimilarityBase): + def __init__(self): + super().__init__(isInterpolated=False) + + def __call__(self, image_A, image_B): + assert ( + image_A.shape == image_B.shape + ), "The shape of image_A and image_B sould be the same." + return torch.mean((image_A - image_B) ** 2) + + +class SSDOnlyInterpolated(SimilarityBase): + def __init__(self): + super().__init__(isInterpolated=True) + + def __call__(self, image_A, image_B): + if len(image_A.shape) - 2 == 3: + dimensions_to_sum_over = [2, 3, 4] + elif len(image_A.shape) - 2 == 2: + dimensions_to_sum_over = [2, 3] + elif len(image_A.shape) - 2 == 1: + dimensions_to_sum_over = [2] + + inbounds_mask = image_A[:, -1:] + image_A = image_A[:, :-1] + assert ( + image_A.shape == image_B.shape + ), "The shape of image_A and image_B sould be the same." 
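+        # Average the squared difference only over voxels marked in-bounds by the
+        # mask channel, so extrapolated regions do not contribute to the loss.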
+ + inbounds_squared_distance = inbounds_mask * (image_A - image_B) ** 2 + sum_squared_distance = torch.sum( + inbounds_squared_distance, dimensions_to_sum_over + ) + divisor = torch.sum(inbounds_mask, dimensions_to_sum_over) + ssds = sum_squared_distance / divisor + return torch.mean(ssds) + + +def flips(phi, in_percentage=False): + if len(phi.size()) == 5: + a = (phi[:, :, 1:, 1:, 1:] - phi[:, :, :-1, 1:, 1:]).detach() + b = (phi[:, :, 1:, 1:, 1:] - phi[:, :, 1:, :-1, 1:]).detach() + c = (phi[:, :, 1:, 1:, 1:] - phi[:, :, 1:, 1:, :-1]).detach() + + dV = torch.sum(torch.cross(a, b, 1) * c, axis=1, keepdims=True) + if in_percentage: + return torch.mean((dV < 0).float()) * 100.0 + else: + return torch.sum(dV < 0) / phi.shape[0] + elif len(phi.size()) == 4: + du = (phi[:, :, 1:, :-1] - phi[:, :, :-1, :-1]).detach() + dv = (phi[:, :, :-1, 1:] - phi[:, :, :-1, :-1]).detach() + dA = du[:, 0] * dv[:, 1] - du[:, 1] * dv[:, 0] + if in_percentage: + return torch.mean((dA < 0).float()) * 100.0 + else: + return torch.sum(dA < 0) / phi.shape[0] + elif len(phi.size()) == 3: + du = (phi[:, :, 1:] - phi[:, :, :-1]).detach() + if in_percentage: + return torch.mean((du < 0).float()) * 100.0 + else: + return torch.sum(du < 0) / phi.shape[0] + else: + raise ValueError() diff --git a/ICON/icon_registration/test_utils.py b/ICON/icon_registration/test_utils.py new file mode 100644 index 00000000..16eec5e5 --- /dev/null +++ b/ICON/icon_registration/test_utils.py @@ -0,0 +1,64 @@ +import pathlib +import subprocess +import sys +import numpy as np + +TEST_DATA_DIR = pathlib.Path(__file__).parent.parent.parent / "test_files" + + +def download_test_data(): + subprocess.run( + [ + "girder-client", + "--api-url", + "https://data.kitware.com/api/v1", + "localsync", + "61d3a99d4acac99f429277d7", + str(TEST_DATA_DIR), + ], + # stdout=sys.stdout, + ) + + +COPD_spacing = { + "copd1": [0.625, 0.625, 2.5], + "copd2": [0.645, 0.645, 2.5], + "copd3": [0.652, 0.652, 2.5], + "copd4": [0.590, 0.590, 2.5], + "copd5": [0.647, 0.647, 2.5], + "copd6": [0.633, 0.633, 2.5], + "copd7": [0.625, 0.625, 2.5], + "copd8": [0.586, 0.586, 2.5], + "copd9": [0.664, 0.664, 2.5], + "copd10": [0.742, 0.742, 2.5], +} + + +def read_copd_pointset(f_path): + """Points are deliminated by '\n' and X,Y,Z of each point are deliminated by '\t'. + + :param f_path: the path to the file containing + the position of points from copdgene dataset. + :return: numpy array of points in physical coordinates + """ + spacing = COPD_spacing[f_path.split("/")[-1].split("_")[0]] + spacing = np.expand_dims(spacing, 0) + with open(f_path) as fp: + content = fp.read().split("\n") + + # Read number of points from second + count = len(content) - 1 + + # Read the points + points = np.ndarray([count, 3], dtype=np.float64) + for i in range(count): + if content[i] == "": + break + temp = content[i].split("\t") + points[i, 0] = float(temp[0]) + points[i, 1] = float(temp[1]) + points[i, 2] = float(temp[2]) + + # The copd gene points are in index space instead of physical space. + # Move them to physical space. 
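+    # (points - 1): the landmark files use 1-based indices; scaling by the
+    # per-case voxel spacing then yields positions in millimetres.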
+ return (points - 1) * spacing diff --git a/ICON/icon_registration/train.py b/ICON/icon_registration/train.py new file mode 100644 index 00000000..721abcab --- /dev/null +++ b/ICON/icon_registration/train.py @@ -0,0 +1,136 @@ +from datetime import datetime + +import torch +import tqdm + +from .losses import to_floats +import icon_registration.config + + +def write_stats(writer, stats, ite): + for k, v in to_floats(stats)._asdict().items(): + writer.add_scalar(k, v, ite) + + +def train_batchfunction( + net, + optimizer, + make_batch, + steps=100000, + step_callback=(lambda net: None), + unwrapped_net=None, +): + """A training function intended for long running experiments, with tensorboard logging + and model checkpoints. Use for medical registration training + """ + import footsteps + from torch.utils.tensorboard import SummaryWriter + + if unwrapped_net is None: + unwrapped_net = net + + loss_curve = [] + writer = SummaryWriter( + footsteps.output_dir + "/" + datetime.now().strftime("%Y%m%d-%H%M%S"), + flush_secs=30, + ) + + visualization_moving, visualization_fixed = [m[:4] for m in make_batch()] + for iteration in range(0, steps): + optimizer.zero_grad() + moving_image, fixed_image = make_batch() + loss_object = net(moving_image, fixed_image) + loss = torch.mean(loss_object.all_loss) + loss.backward() + + step_callback(unwrapped_net) + + print(to_floats(loss_object)) + write_stats(writer, loss_object, iteration) + optimizer.step() + + if iteration % 300 == 0: + torch.save( + optimizer.state_dict(), + footsteps.output_dir + "optimizer_weights_" + str(iteration), + ) + torch.save( + unwrapped_net.regis_net.state_dict(), + footsteps.output_dir + "network_weights_" + str(iteration), + ) + unwrapped_net.eval() + print("val (from train set)") + warped = [] + with torch.no_grad(): + for i in range(4): + print( + unwrapped_net( + visualization_moving[i : i + 1], + visualization_fixed[i : i + 1], + ) + ) + warped.append(unwrapped_net.warped_image_A.cpu()) + warped = torch.cat(warped) + unwrapped_net.train() + + def render(im): + if len(im.shape) == 5: + im = im[:, :, :, :, im.shape[4] // 2] + if torch.min(im) < 0: + im = im - torch.min(im) + if torch.max(im) > 1: + im = im / torch.max(im) + return im[:4, [0, 0, 0]].detach().cpu() + + writer.add_images( + "moving_image", + render(visualization_moving[:4]), + iteration, + dataformats="NCHW", + ) + writer.add_images( + "fixed_image", + render(visualization_fixed[:4]), + iteration, + dataformats="NCHW", + ) + writer.add_images( + "warped_moving_image", + render(warped), + iteration, + dataformats="NCHW", + ) + writer.add_images( + "difference", + render( + torch.clip( + (warped[:4, :1] - visualization_fixed[:4, :1].cpu()) + 0.5, 0, 1 + ) + ), + iteration, + dataformats="NCHW", + ) + + +def train_datasets(net, optimizer, d1, d2, epochs=400): + """A training function for quick experiments""" + batch_size = net.identity_map.shape[0] + loss_history = [] + for epoch in tqdm.tqdm(range(epochs)): + for A, B in list(zip(d1, d2)): + if True: # A[0].size()[0] == batch_size: + image_A = A[0].to(icon_registration.config.device) + image_B = B[0].to(icon_registration.config.device) + optimizer.zero_grad() + + loss_object = net(image_A, image_B) + + loss_object["all_loss"].backward() + optimizer.step() + + loss_history.append(to_floats(loss_object)) + print(to_floats(loss_object)) + return loss_history + + +train2d = train_datasets diff --git a/ICON/icon_registration/visualize.py b/ICON/icon_registration/visualize.py new file mode 100644 index 00000000..fbfa3d12 
--- /dev/null +++ b/ICON/icon_registration/visualize.py @@ -0,0 +1,42 @@ +import torch +import torchvision +import matplotlib.pyplot as plt +import footsteps + +def show(tensor): + plt.imshow(torchvision.utils.make_grid(tensor[:6], nrow=3)[0].cpu().detach()) + plt.xticks([]) + plt.yticks([]) + +def render(im): + if len(im.shape) == 5: + im = im[:, :, :, :, im.shape[4] // 2] + if torch.min(im) < 0: + im = im - torch.min(im) + if torch.max(im) > 1: + im = im / torch.max(im) + return im[:4, [0, 0, 0]].detach().cpu() + + +def plot_registration_result(image_A, image_B, net) : + + res = net(image_A, image_B) + print(res.keys()) + + + + + plt.subplot(2, 2, 1) + show(image_A) + plt.subplot(2, 2, 2) + show(image_B) + plt.subplot(2, 2, 3) + show(res["warped_image_A"]) + + phi_AB_vectorfield = res["phi_AB"](net.identity_map) + plt.contour(torchvision.utils.make_grid(phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach()) + plt.contour(torchvision.utils.make_grid(phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach()) + plt.subplot(2, 2, 4) + show(res["warped_image_A"]- image_B) + plt.tight_layout() + footsteps.plot() diff --git a/ICON/pyproject.toml b/ICON/pyproject.toml new file mode 100644 index 00000000..b3887988 --- /dev/null +++ b/ICON/pyproject.toml @@ -0,0 +1,6 @@ +[build-system] +requires = [ + "setuptools>42", + "wheel" +] +build-backend = "setuptools.build_meta" diff --git a/ICON/requirements.txt b/ICON/requirements.txt new file mode 100644 index 00000000..39d9d99f --- /dev/null +++ b/ICON/requirements.txt @@ -0,0 +1,9 @@ +torch +torchvision +tensorboard +tqdm +matplotlib +monai +footsteps>=0.1.6 +itk==5.3.0 +girder_client==3.1.8 diff --git a/ICON/setup.cfg b/ICON/setup.cfg new file mode 100644 index 00000000..30bb9b51 --- /dev/null +++ b/ICON/setup.cfg @@ -0,0 +1,30 @@ +[metadata] +name = icon +version = 1.1.2 +author = Hastings Greer +author_email = t@hgreer.com +description = A package for image registration regularized by inverse consistency +long_description = file: README.md +long_description_content_type = text/markdown +classifiers = + Programming Language :: Python :: 3 + License :: OSI Approved :: Apache Software License + Operating System :: POSIX :: Linux + +[options] +packages = icon_registration +python_requires = >=3.4 + +install_requires = + torch + torchvision + tensorboard + tqdm + matplotlib + monai + footsteps>=0.1.6 + itk==5.3.0 + girder_client==3.1.8 + +[options.packages.find] +where = src diff --git a/ICON/setup.py b/ICON/setup.py new file mode 100644 index 00000000..60684932 --- /dev/null +++ b/ICON/setup.py @@ -0,0 +1,3 @@ +from setuptools import setup + +setup() diff --git a/ICON/test/__init__.py b/ICON/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ICON/test/test_2d_registration_train.py b/ICON/test/test_2d_registration_train.py new file mode 100644 index 00000000..385ec8b6 --- /dev/null +++ b/ICON/test/test_2d_registration_train.py @@ -0,0 +1,119 @@ +import unittest + + +class Test2DRegistrationTrain(unittest.TestCase): + def test_2d_registration_train_ICON(self): + import icon_registration + + import icon_registration.data as data + import icon_registration.networks as networks + from icon_registration import SSD + import icon_registration.visualize + + import numpy as np + import torch + import random + import os + + random.seed(1) + torch.manual_seed(1) + torch.cuda.manual_seed(1) + np.random.seed(1) + + batch_size = 128 + + d1, d2 = data.get_dataset_triangles( + data_size=50, hollow=False, batch_size=batch_size + ) + d1_t, d2_t = 
data.get_dataset_triangles( + data_size=50, hollow=False, batch_size=batch_size + ) + + lmbda = 2048 + + print("ICON training") + net = icon_registration.ICON( + icon_registration.DisplacementField(networks.tallUNet2(dimension=2)), + # Our image similarity metric. The last channel of x and y is whether the value is interpolated or extrapolated, + # which is used by some metrics but not this one + SSD(), + lmbda, + ) + + input_shape = list(next(iter(d1))[0].size()) + input_shape[0] = 1 + net.assign_identity_map(input_shape) + print(net.identity_map) + net.cuda() + optimizer = torch.optim.Adam(net.parameters(), lr=0.001) + net.train() + + y = icon_registration.train_datasets(net, optimizer, d1, d2, epochs=5) + import matplotlib.pyplot as plt + import footsteps + plt.plot([step["similarity_loss"] for step in y]) + footsteps.plot("2d-icon-similarity") + + test_A, test_B = [next(iter(d1))[0].cuda() for _ in range(2)] + icon_registration.visualize.plot_registration_result(test_A, test_B, net) + + # Test that image similarity is good enough + self.assertLess(np.mean(np.array([step["similarity_loss"] for step in y])[-5:]), 0.1) + + # Test that folds are rare enough + self.assertLess(np.mean(np.exp(np.array(y)[-5:, 3] - 0.1)), 2) + for l in y: + print(l) + + def test_2d_registration_train_GradICON(self): + import icon_registration + + import icon_registration.data as data + import icon_registration.networks as networks + from icon_registration import SSD + + import numpy as np + import torch + import random + import os + + random.seed(1) + torch.manual_seed(1) + torch.cuda.manual_seed(1) + np.random.seed(1) + + batch_size = 128 + + d1, d2 = data.get_dataset_triangles( + data_size=50, hollow=False, batch_size=batch_size + ) + d1_t, d2_t = data.get_dataset_triangles( + data_size=50, hollow=False, batch_size=batch_size + ) + + lmbda = 1.0 + + print("GradICON training") + net = icon_registration.GradICON( + icon_registration.DisplacementField(networks.tallUNet2(dimension=2)), + # Our image similarity metric. 
The last channel of x and y is whether the value is interpolated or extrapolated, + # which is used by some metrics but not this one + SSD(), + lmbda, + ) + + input_shape = next(iter(d1))[0].size() + net.assign_identity_map(input_shape) + net.cuda() + optimizer = torch.optim.Adam(net.parameters(), lr=0.001) + net.train() + + y = icon_registration.train_datasets(net, optimizer, d1, d2, epochs=5) + + # Test that image similarity is good enough + self.assertLess(np.mean(np.array(y)[-5:, 1]), 0.1) + + # Test that folds are rare enough + self.assertLess(np.mean(np.exp(np.array(y)[-5:, 3] - 0.1)), 2) + for l in y: + print(l) diff --git a/ICON/test/test_brain_itk.py b/ICON/test/test_brain_itk.py new file mode 100644 index 00000000..249dc92e --- /dev/null +++ b/ICON/test/test_brain_itk.py @@ -0,0 +1,83 @@ +import itk +import numpy as np +import unittest +import matplotlib.pyplot as plt +import numpy as np + +import icon_registration.test_utils +import icon_registration.pretrained_models +import icon_registration.itk_wrapper + + +class TestItkRegistration(unittest.TestCase): + def test_itk_registration(self): + print("brain GradICON") + import os + + os.environ["FOOTSTEPS_NAME"] = "test" + import footsteps + + icon_registration.test_utils.download_test_data() + + model = icon_registration.pretrained_models.brain_registration_model( + pretrained=True + ) + + image_A = itk.imread( + f"{icon_registration.test_utils.TEST_DATA_DIR}/brain_test_data/2_T1w_acpc_dc_restore_brain.nii.gz" + ) + + image_B = itk.imread( + f"{icon_registration.test_utils.TEST_DATA_DIR}/brain_test_data/8_T1w_acpc_dc_restore_brain.nii.gz" + ) + + image_A_processed = ( + icon_registration.pretrained_models.brain_network_preprocess(image_A) + ) + + image_B_processed = ( + icon_registration.pretrained_models.brain_network_preprocess(image_B) + ) + + phi_AB, phi_BA = icon_registration.itk_wrapper.register_pair( + model, image_A_processed, image_B_processed + ) + + assert isinstance(phi_AB, itk.CompositeTransform) + interpolator = itk.LinearInterpolateImageFunction.New(image_A) + + warped_image_A = itk.resample_image_filter( + image_A_processed, + transform=phi_AB, + interpolator=interpolator, + size=itk.size(image_B), + output_spacing=itk.spacing(image_B), + output_direction=image_B.GetDirection(), + output_origin=image_B.GetOrigin(), + ) + + plt.imshow( + np.array(itk.checker_board_image_filter(warped_image_A, image_B_processed))[ + 40 + ] + ) + plt.colorbar() + plt.savefig(footsteps.output_dir + "grid.png") + plt.clf() + plt.imshow(np.array(warped_image_A)[40]) + plt.savefig(footsteps.output_dir + "warped.png") + plt.clf() + + reference = np.load( + icon_registration.test_utils.TEST_DATA_DIR + / "brain_test_data/2_and_8_warped_itkfix.npy" + ) + + np.save( + footsteps.output_dir + "warped_brain.npy", + itk.array_from_image(warped_image_A)[40], + ) + + self.assertLess( + np.mean(np.abs(reference - itk.array_from_image(warped_image_A)[40])), 1e-5 + ) diff --git a/ICON/test/test_imports.py b/ICON/test/test_imports.py new file mode 100644 index 00000000..535097c4 --- /dev/null +++ b/ICON/test/test_imports.py @@ -0,0 +1,28 @@ +import unittest + + +class TestImports(unittest.TestCase): + def test_imports_CPU(self): + import icon_registration + import icon_registration.networks + import icon_registration.data + + def test_requirements_match_cfg_CPU(self): + from inspect import getsourcefile + import os.path as path, sys + import configparser + + current_dir = path.dirname(path.abspath(getsourcefile(lambda: 0))) + parent_dir = current_dir[: 
current_dir.rfind(path.sep)] + + with open(parent_dir + "/requirements.txt") as f: + requirements_txt = "\n" + f.read() + requirements_cfg = configparser.ConfigParser() + requirements_cfg.read(parent_dir + "/setup.cfg") + requirements_cfg = requirements_cfg["options"]["install_requires"] + "\n" + self.assertEqual(requirements_txt, requirements_cfg) + + def test_pytorch_cuda(self): + import torch + + x = torch.Tensor([7]).cuda diff --git a/ICON/test/test_knee_itk.py b/ICON/test/test_knee_itk.py new file mode 100644 index 00000000..81b67624 --- /dev/null +++ b/ICON/test/test_knee_itk.py @@ -0,0 +1,128 @@ +import itk +import numpy as np +import unittest +import matplotlib.pyplot as plt +import numpy as np + +import icon_registration.test_utils +import icon_registration.pretrained_models +import icon_registration.itk_wrapper + + +class TestItkRegistration(unittest.TestCase): + def test_itk_registration(self): + import os + + os.environ["FOOTSTEPS_NAME"] = "test" + import footsteps + + icon_registration.test_utils.download_test_data() + + model = icon_registration.pretrained_models.OAI_knees_registration_model( + pretrained=True + ) + + image_A = itk.imread( + str( + icon_registration.test_utils.TEST_DATA_DIR + / "knees_diverse_sizes" + / + # "9126260_20060921_SAG_3D_DESS_LEFT_11309302_image.nii.gz") + "9487462_20081003_SAG_3D_DESS_RIGHT_11495603_image.nii.gz" + ) + ) + + image_B = itk.imread( + str( + icon_registration.test_utils.TEST_DATA_DIR + / "knees_diverse_sizes" + / "9225063_20090413_SAG_3D_DESS_RIGHT_12784112_image.nii.gz" + ) + ) + print(image_A.GetLargestPossibleRegion().GetSize()) + print(image_B.GetLargestPossibleRegion().GetSize()) + print(image_A.GetSpacing()) + print(image_B.GetSpacing()) + + phi_AB, phi_BA = icon_registration.itk_wrapper.register_pair( + model, image_A, image_B + ) + + assert isinstance(phi_AB, itk.CompositeTransform) + interpolator = itk.LinearInterpolateImageFunction.New(image_A) + + warped_image_A = itk.resample_image_filter( + image_A, + transform=phi_AB, + interpolator=interpolator, + size=itk.size(image_B), + output_spacing=itk.spacing(image_B), + output_direction=image_B.GetDirection(), + output_origin=image_B.GetOrigin(), + ) + + plt.imshow( + np.array(itk.checker_board_image_filter(warped_image_A, image_B))[40] + ) + plt.colorbar() + plt.savefig(footsteps.output_dir + "grid.png") + plt.clf() + plt.imshow(np.array(warped_image_A)[40]) + plt.savefig(footsteps.output_dir + "warped.png") + plt.clf() + + reference = np.load( + icon_registration.test_utils.TEST_DATA_DIR / "warped_itkfix.npy" + ) + + np.save( + footsteps.output_dir + "warped_knee.npy", + itk.array_from_image(warped_image_A)[40], + ) + + self.assertLess( + np.mean(np.abs(reference - itk.array_from_image(warped_image_A)[40])), 1e-6 + ) + + def test_itk_consistency(self): + import torch + + img = torch.tensor([[1, 2, 3], [4, 5.0, 6], [7, 8, 13]])[None, None] + + warper = icon_registration.RegistrationModule() + + warper.assign_identity_map([1, 1, 3, 3]) + + deformation = torch.zeros((1, 2, 3, 3)) + + deformation[0, 0, 0, 0] = 0.5 + + deformation[0, 1, 2, 2] = -0.5 + + img_func = warper.as_function(img) + + warped_image_torch = img_func(warper.identity_map + deformation) + + itk_img = itk.image_from_array(img[0, 0]) + + sp = itk.Vector[itk.D, 2]((4.2, 6.9)) + + itk_img.SetSpacing(sp) + + phi = icon_registration.itk_wrapper.create_itk_transform( + deformation + warper.identity_map, warper.identity_map, itk_img, itk_img + ) + + interpolator = itk.LinearInterpolateImageFunction.New(itk_img) + + 
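+        # Resample with ITK using the converted transform; the assertion below
+        # checks that it reproduces the torch-side warp voxel-for-voxel, i.e.
+        # that the two coordinate conventions agree.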
warped_image_A = itk.resample_image_filter( + itk_img, + transform=phi, + interpolator=interpolator, + use_reference_image=True, + reference_image=itk_img, + ) + + self.assertTrue( + np.all(np.array(warped_image_A) == np.array(warped_image_torch)) + ) diff --git a/ICON/test/test_knee_registration.py b/ICON/test/test_knee_registration.py new file mode 100644 index 00000000..609ceb42 --- /dev/null +++ b/ICON/test/test_knee_registration.py @@ -0,0 +1,183 @@ +import unittest + + +class TestKneeRegistration(unittest.TestCase): + def test_knee_registration(self): + print("OAI ICON") + + import icon_registration.pretrained_models + from icon_registration.mermaidlite import compute_warped_image_multiNC + from icon_registration.losses import flips + + import torch + import numpy as np + + import subprocess + + print("Downloading test data)") + import icon_registration.test_utils + + icon_registration.test_utils.download_test_data() + t_ds = torch.load( + icon_registration.test_utils.TEST_DATA_DIR / "icon_example_data" + ) + batched_ds = list(zip(*[t_ds[i::2] for i in range(2)])) + net = icon_registration.pretrained_models.OAI_knees_registration_model( + pretrained=True + ) + # Run on the four downloaded image pairs + + with torch.no_grad(): + dices = [] + folds_list = [] + + for x in batched_ds[:]: + # Seperate the image data used for registration from the segmentation used for evaluation, + # and shape it for passing to the network + x = list(zip(*x)) + x = [torch.cat(r, 0).cuda().float() for r in x] + fixed_image, fixed_cartilage = x[0], x[2] + moving_image, moving_cartilage = x[1], x[3] + + # Run the registration. + # Our network expects batches of two pairs, + # moving_image.size = torch.Size([2, 1, 80, 192, 192]) + # fixed_image.size = torch.Size([2, 1, 80, 192, 192]) + # intensity normalized to have min 0 and max 1. + + net(moving_image, fixed_image) + + # Once registration is run, net.phi_AB and net.phi_BA are functions that map + # tensors of coordinates from image B to A and A to B respectively. + + # Evaluate the registration + # First, evaluate phi_AB on a tensor of coordinates to get an explicit map. + phi_AB_vectorfield = net.phi_AB(net.identity_map) + fat_phi = torch.nn.Upsample( + size=moving_cartilage.size()[2:], + mode="trilinear", + align_corners=False, + )(phi_AB_vectorfield[:, :3]) + sz = np.array(fat_phi.size()) + spacing = 1.0 / (sz[2::] - 1) + + # Warp the cartilage of one image to match the other using the explicit map. 
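+                # compute_warped_image_multiNC samples the cartilage mask at the
+                # upsampled phi_AB coordinates; spacing = 1 / (size - 1) matches
+                # the normalized coordinates the map is expressed in.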
+ warped_moving_cartilage = compute_warped_image_multiNC( + moving_cartilage.float(), fat_phi, spacing, 1 + ) + + # Binarize the segmentations + wmb = warped_moving_cartilage > 0.5 + fb = fixed_cartilage > 0.5 + + # Compute the dice metric + intersection = wmb * fb + dice = ( + 2 + * torch.sum(intersection, [1, 2, 3, 4]).float() + / (torch.sum(wmb, [1, 2, 3, 4]) + torch.sum(fb, [1, 2, 3, 4])) + ) + print("Batch DICE:", dice) + dices.append(dice) + + # Compute the folds metric + f = [flips(phi[None]).item() for phi in phi_AB_vectorfield] + print("Batch folds per image:", f) + folds_list.append(f) + + mean_dice = torch.mean(torch.cat(dices).cpu()) + print("Mean DICE SCORE:", mean_dice) + self.assertTrue(mean_dice.item() > 0.68) + mean_folds = np.mean(folds_list) + print("Mean folds per image:", mean_folds) + self.assertTrue(mean_folds < 300) + + def test_knee_registration_gradICON(self): + print("OAI gradICON") + + import icon_registration.pretrained_models + from icon_registration.mermaidlite import compute_warped_image_multiNC + from icon_registration.losses import flips + + import torch + import numpy as np + + import subprocess + + print("Downloading test data)") + import icon_registration.test_utils + + icon_registration.test_utils.download_test_data() + t_ds = torch.load( + icon_registration.test_utils.TEST_DATA_DIR / "icon_example_data" + ) + batched_ds = list(zip(*[t_ds[i::2] for i in range(2)])) + net = icon_registration.pretrained_models.OAI_knees_gradICON_model( + pretrained=True + ) + # Run on the four downloaded image pairs + + with torch.no_grad(): + dices = [] + folds_list = [] + + for x in batched_ds[:]: + # Seperate the image data used for registration from the segmentation used for evaluation, + # and shape it for passing to the network + x = list(zip(*x)) + x = [torch.cat(r, 0).cuda().float() for r in x] + fixed_image, fixed_cartilage = x[0], x[2] + moving_image, moving_cartilage = x[1], x[3] + + # Run the registration. + # Our network expects batches of two pairs, + # moving_image.size = torch.Size([2, 1, 80, 192, 192]) + # fixed_image.size = torch.Size([2, 1, 80, 192, 192]) + # intensity normalized to have min 0 and max 1. + + net(moving_image, fixed_image) + + # Once registration is run, net.phi_AB and net.phi_BA are functions that map + # tensors of coordinates from image B to A and A to B respectively. + + # Evaluate the registration + # First, evaluate phi_AB on a tensor of coordinates to get an explicit map. + phi_AB_vectorfield = net.phi_AB(net.identity_map) + fat_phi = torch.nn.Upsample( + size=moving_cartilage.size()[2:], + mode="trilinear", + align_corners=False, + )(phi_AB_vectorfield[:, :3]) + sz = np.array(fat_phi.size()) + spacing = 1.0 / (sz[2::] - 1) + + # Warp the cartilage of one image to match the other using the explicit map. 
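+                # (Same evaluation as the ICON knee test above: warp the cartilage,
+                # binarize, then compute Dice and the folds metric.)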
+ warped_moving_cartilage = compute_warped_image_multiNC( + moving_cartilage.float(), fat_phi, spacing, 1 + ) + + # Binarize the segmentations + wmb = warped_moving_cartilage > 0.5 + fb = fixed_cartilage > 0.5 + + # Compute the dice metric + intersection = wmb * fb + dice = ( + 2 + * torch.sum(intersection, [1, 2, 3, 4]).float() + / (torch.sum(wmb, [1, 2, 3, 4]) + torch.sum(fb, [1, 2, 3, 4])) + ) + print("Batch DICE:", dice) + dices.append(dice) + + # Compute the folds metric + f = [flips(phi[None]).item() for phi in phi_AB_vectorfield] + print("Batch folds per image:", f) + folds_list.append(f) + + mean_dice = torch.mean(torch.cat(dices).cpu()) + print("Mean DICE SCORE:", mean_dice) + self.assertTrue(mean_dice.item() > 0.68) + mean_folds = np.mean(folds_list) + print("Mean folds per image:", mean_folds) + self.assertTrue(mean_folds < 300) diff --git a/ICON/test/test_lung_itk.py b/ICON/test/test_lung_itk.py new file mode 100644 index 00000000..90a1c850 --- /dev/null +++ b/ICON/test/test_lung_itk.py @@ -0,0 +1,125 @@ +import itk +import numpy as np +import unittest +import matplotlib.pyplot as plt +import numpy as np + +import icon_registration.test_utils +import icon_registration.pretrained_models +import icon_registration.itk_wrapper + + +class TestItkRegistration(unittest.TestCase): + def test_itk_registration(self): + model = icon_registration.pretrained_models.LungCT_registration_model( + pretrained=True + ) + + icon_registration.test_utils.download_test_data() + + image_exp = itk.imread( + str( + icon_registration.test_utils.TEST_DATA_DIR + / "lung_test_data/copd1_highres_EXP_STD_COPD_img.nii.gz" + ) + ) + image_insp = itk.imread( + str( + icon_registration.test_utils.TEST_DATA_DIR + / "lung_test_data/copd1_highres_INSP_STD_COPD_img.nii.gz" + ) + ) + image_exp_seg = itk.imread( + str( + icon_registration.test_utils.TEST_DATA_DIR + / "lung_test_data/copd1_highres_EXP_STD_COPD_label.nii.gz" + ) + ) + image_insp_seg = itk.imread( + str( + icon_registration.test_utils.TEST_DATA_DIR + / "lung_test_data/copd1_highres_INSP_STD_COPD_label.nii.gz" + ) + ) + + image_insp_preprocessed = ( + icon_registration.pretrained_models.lung_network_preprocess( + image_insp, image_insp_seg + ) + ) + image_exp_preprocessed = ( + icon_registration.pretrained_models.lung_network_preprocess( + image_exp, image_exp_seg + ) + ) + + phi_AB, phi_BA = icon_registration.itk_wrapper.register_pair( + model, image_insp_preprocessed, image_exp_preprocessed, finetune_steps=None + ) + + assert isinstance(phi_AB, itk.CompositeTransform) + interpolator = itk.LinearInterpolateImageFunction.New(image_insp_preprocessed) + + warped_image_insp_preprocessed = itk.resample_image_filter( + image_insp_preprocessed, + transform=phi_AB, + interpolator=interpolator, + size=itk.size(image_exp_preprocessed), + output_spacing=itk.spacing(image_exp_preprocessed), + output_direction=image_exp_preprocessed.GetDirection(), + output_origin=image_exp_preprocessed.GetOrigin(), + ) + + # log some images to show the registration + import os + + os.environ["FOOTSTEPS_NAME"] = "test" + import footsteps + + plt.imshow( + np.array( + itk.checker_board_image_filter( + warped_image_insp_preprocessed, image_exp_preprocessed + ) + )[140] + ) + plt.colorbar() + plt.savefig(footsteps.output_dir + "grid_lung.png") + plt.clf() + plt.imshow(np.array(warped_image_insp_preprocessed)[140]) + plt.colorbar() + plt.savefig(footsteps.output_dir + "warped_lung.png") + plt.clf() + plt.imshow( + np.array(warped_image_insp_preprocessed)[140] + - 
np.array(image_exp_preprocessed)[140] + ) + plt.colorbar() + plt.savefig(footsteps.output_dir + "difference_lung.png") + plt.clf() + + insp_points = icon_registration.test_utils.read_copd_pointset( + "test_files/lung_test_data/copd1_300_iBH_xyz_r1.txt" + ) + exp_points = icon_registration.test_utils.read_copd_pointset( + "test_files/lung_test_data/copd1_300_eBH_xyz_r1.txt" + ) + dists = [] + for i in range(len(insp_points)): + px, py = ( + exp_points[i], + np.array(phi_BA.TransformPoint(tuple(insp_points[i]))), + ) + dists.append(np.sqrt(np.sum((px - py) ** 2))) + print(np.mean(dists)) + + self.assertLess(np.mean(dists), 1.5) + dists = [] + for i in range(len(insp_points)): + px, py = ( + insp_points[i], + np.array(phi_AB.TransformPoint(tuple(exp_points[i]))), + ) + dists.append(np.sqrt(np.sum((px - py) ** 2))) + print(np.mean(dists)) + self.assertLess(np.mean(dists), 2.3) diff --git a/ICON/test/test_lung_registration.py b/ICON/test/test_lung_registration.py new file mode 100644 index 00000000..f998264a --- /dev/null +++ b/ICON/test/test_lung_registration.py @@ -0,0 +1,308 @@ +import os +import unittest + +import itk +import numpy as np +import torch +import torch.nn.functional as F +import tqdm + +import icon_registration +import icon_registration.pretrained_models + +COPD_spacing = { + "copd1": [0.625, 0.625, 2.5], + "copd2": [0.645, 0.645, 2.5], + "copd3": [0.652, 0.652, 2.5], + "copd4": [0.590, 0.590, 2.5], + "copd5": [0.647, 0.647, 2.5], + "copd6": [0.633, 0.633, 2.5], + "copd7": [0.625, 0.625, 2.5], + "copd8": [0.586, 0.586, 2.5], + "copd9": [0.664, 0.664, 2.5], + "copd10": [0.742, 0.742, 2.5], +} + + +def readPoint(f_path): + """ + :param f_path: the path to the file containing the position of points. + Points are deliminated by '\n' and X,Y,Z of each point are deliminated by '\t'. + :return: numpy list of positions. + """ + with open(f_path) as fp: + content = fp.read().split("\n") + + # Read number of points from second + count = len(content) - 1 + + # Read the points + points = np.ndarray([count, 3], dtype=np.float64) + + for i in range(count): + if content[i] == "": + break + temp = content[i].split("\t") + points[i, 0] = float(temp[0]) + points[i, 1] = float(temp[1]) + points[i, 2] = float(temp[2]) + + return points + + +def calc_warped_points(source_list_t, phi_t, dim, spacing, phi_spacing): + """ + :param source_list_t: source image. + :param phi_t: the inversed displacement. Domain in source coordinate. + :param dim: voxel dimenstions. + :param spacing: image spacing. + :return: a N*3 tensor containg warped positions in the physical coordinate. + """ + warped_list_t = F.grid_sample(phi_t, source_list_t, align_corners=True) + + warped_list_t = torch.flip(warped_list_t.permute(0, 2, 3, 4, 1), [4])[0, 0, 0] + warped_list_t = torch.mul( + torch.mul(warped_list_t, torch.from_numpy(dim - 1.0)), + torch.from_numpy(phi_spacing), + ) + + return warped_list_t + + +def eval_with_data(source_list, target_list, phi, dim, spacing, origin, phi_spacing): + """ + :param source_list: a numpy list of markers' position in source image. + :param target_list: a numpy list of markers' position in target image. + :param phi: displacement map in numpy format. + :param dim: voxel dimenstions. + :param spacing: image spacing. + :param return: res, [dist_x, dist_y, dist_z] res is the distance between + the warped points and target points in MM. [dist_x, dist_y, dist_z] are + distances in MM along x,y,z axis perspectively. 
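+        Here res is the mean Euclidean landmark distance; dist_x, dist_y and
+        dist_z are the mean absolute per-axis errors, all in millimetres.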
+ """ + origin_list = np.repeat( + [ + origin, + ], + target_list.shape[0], + axis=0, + ) + + # Translate landmark from landmark coord to phi coordinate + target_list_t = ( + torch.from_numpy((target_list - 1.0) * spacing) - origin_list * phi_spacing + ) + source_list_t = ( + torch.from_numpy((source_list - 1.0) * spacing) - origin_list * phi_spacing + ) + + # Pay attention to the definition of align_corners in grid_sampling. + # Translate landmarks to voxel index in image space [-1, 1] + source_list_norm = source_list_t / phi_spacing / (dim - 1.0) * 2.0 - 1.0 + source_list_norm = source_list_norm.unsqueeze(0).unsqueeze(0).unsqueeze(0) + + phi_t = torch.from_numpy(phi).double() + + warped_list_t = calc_warped_points( + source_list_norm, phi_t, dim, spacing, phi_spacing + ) + + pdist = torch.nn.PairwiseDistance(p=2) + dist = pdist(target_list_t, warped_list_t) + idx = torch.argsort(dist).numpy() + dist_x = torch.mean(torch.abs(target_list_t[:, 0] - warped_list_t[:, 0])).item() + dist_y = torch.mean(torch.abs(target_list_t[:, 1] - warped_list_t[:, 1])).item() + dist_z = torch.mean(torch.abs(target_list_t[:, 2] - warped_list_t[:, 2])).item() + res = torch.mean(dist).item() + + return res, [dist_x, dist_y, dist_z] + + +def compute_dice(x, y): + eps = 1e-11 + y_loc = set(np.where(y.flatten() == 1)[0]) + x_loc = set(np.where(x.flatten() == 1)[0]) + # iou + intersection = set.intersection(x_loc, y_loc) + # recall + len_intersection = len(intersection) + tp = float(len_intersection) + fn = float(len(y_loc) - len_intersection) + fp = float(len(x_loc) - len_intersection) + + if len(y_loc) != 0 or len(x_loc) != 0: + return 2 * tp / (2 * tp + fn + fp + eps) + + return 0.0 + + +def eval_copd_highres(seg_folder, case_id, phi, phi_inv, phi_spacing, origin, dim): + """ + :param dataset_path: the path to the dataset folder. The folder structure assumption: + dataset_path/landmarks stores landmarks files; dataset_path/segments stores segmentation maps + + :param case_id: a list of case id + :param phi: phi defined in expiration image domain. Numpy array. Bx3xDxWxH. Orientation: SI, AP, RL orientation + :param phi_inv: phi_inv defined in inspiration image domain. Numpy array. Bx3xDxWxH + :param phi_spacing: physical spacing of the domain of phi. Numpy array. 
Bx3xDxWxH + """ + + def _eval(phi, case_id, phi_spacing, origin, dim, inv=False): + result = {} + if inv: + source_file = os.path.join(seg_folder, f"{case_id}_300_iBH_xyz_r1.txt") + target_file = os.path.join(seg_folder, f"{case_id}_300_eBH_xyz_r1.txt") + else: + source_file = os.path.join(seg_folder, f"{case_id}_300_eBH_xyz_r1.txt") + target_file = os.path.join(seg_folder, f"{case_id}_300_iBH_xyz_r1.txt") + + spacing = COPD_spacing[case_id] + + source_list = readPoint(source_file) + target_list = readPoint(target_file) + + mTRE, mTRE_seperate = eval_with_data( + source_list, target_list, phi, dim, spacing, origin, phi_spacing + ) + + result["mTRE"] = mTRE + result["mTRE_X"] = mTRE_seperate[0] + result["mTRE_Y"] = mTRE_seperate[1] + result["mTRE_Z"] = mTRE_seperate[2] + + return result + + results = {} + for i, case in enumerate(case_id): + results[case] = _eval( + phi[i : i + 1], case, phi_spacing[i], origin[i], dim[i], inv=False + ) + results[f"{case}_inv"] = _eval( + phi_inv[i : i + 1], case, phi_spacing[i], origin[i], dim[i], inv=True + ) + return results + + +class TestLungRegistration(unittest.TestCase): + def test_lung_registration(self): + print("lung gradICON") + + net = icon_registration.pretrained_models.LungCT_registration_model( + pretrained=True + ) + icon_registration.test_utils.download_test_data() + + root = str(icon_registration.test_utils.TEST_DATA_DIR / "lung_test_data") + cases = [f"copd{i}_highres" for i in range(1, 2)] + hu_clip_range = [-1000, 0] + + # Load data + def process(iA, isSeg=False): + iA = iA[None, None, :, :, :] + if isSeg: + iA = iA.float() + iA = torch.nn.functional.max_pool3d(iA, 2) + iA[iA > 0] = 1 + else: + iA = torch.clip(iA, hu_clip_range[0], hu_clip_range[1]) + iA = iA / 1000 + + iA = torch.nn.functional.avg_pool3d(iA, 2) + return iA + + def read_itk(path): + img = itk.imread(path) + return torch.tensor(np.asarray(img)), np.flipud(list(img.GetOrigin())) + + dirlab = [] + dirlab_seg = [] + dirlab_origin = [] + for name in tqdm.tqdm(list(iter(cases))[:]): + image_insp, _ = read_itk(f"{root}/{name}_INSP_STD_COPD_img.nii.gz") + image_exp, _ = read_itk(f"{root}/{name}_EXP_STD_COPD_img.nii.gz") + seg_insp, _ = read_itk(f"{root}/{name}_INSP_STD_COPD_label.nii.gz") + seg_exp, origin = read_itk(f"{root}/{name}_EXP_STD_COPD_label.nii.gz") + + # dirlab.append((process(image_insp), process(image_exp))) + dirlab.append( + ( + ((process(image_insp) + 1) * process(seg_insp, True)), + ((process(image_exp) + 1) * process(seg_exp, True)), + ) + ) + dirlab_seg.append((process(seg_insp, True), process(seg_exp, True))) + dirlab_origin.append(origin) + + def make_batch(dataset, dataset_seg): + image_A = torch.cat([p[0] for p in dataset]).cuda() + image_B = torch.cat([p[1] for p in dataset]).cuda() + + image_A_seg = torch.cat([p[0] for p in dataset_seg]).cuda() + image_B_seg = torch.cat([p[1] for p in dataset_seg]).cuda() + return image_A, image_B, image_A_seg, image_B_seg + + net.cuda() + + train_A, train_B, train_A_seg, train_B_seg = make_batch(dirlab, dirlab_seg) + + phis = [] + phis_inv = [] + warped_A = [] + for i in range(train_A.shape[0]): + with torch.no_grad(): + print(net(train_A[i : i + 1], train_B[i : i + 1])) + + phis.append((net.phi_AB_vectorfield.detach() * 2.0 - 1.0)) + phis_inv.append((net.phi_BA_vectorfield.detach() * 2.0 - 1.0)) + warped_A.append(net.warped_image_A[:, 0:1].detach()) + + warped_A = torch.cat(warped_A) + + phis_np = (torch.cat(phis).cpu().numpy() + 1.0) / 2.0 + phis_inv_np = (torch.cat(phis_inv).cpu().numpy() + 1.0) / 2.0 + res = 
eval_copd_highres( + root, + [c[:-8] for c in cases], + phis_np, + phis_inv_np, + np.repeat(np.array([[1.0, 1.0, 1.0]]), len(cases), axis=0), + np.array([np.flipud(i) for i in dirlab_origin]), + np.repeat(np.array([[350, 350, 350]]), len(cases), axis=0), + ) + + results = [] + results_inv = [] + for k, v in res.items(): + result = [] + for m, n in v.items(): + result.append(n) + if "_inv" in k: + results_inv.append(result) + else: + results.append(result) + results = np.array(results) + + results_str = f"mTRE: {results[:,0].mean()}, mTRE_X: {results[:,1].mean()}, mTRE_Y: {results[:,2].mean()}, mTRE_Z: {results[:,3].mean()}" + print(results_str) + results_inv = np.array(results_inv) + results_inv_str = f"mTRE: {results_inv[:,0].mean()}, mTRE_X: {results_inv[:,1].mean()}, mTRE_Y: {results_inv[:,2].mean()}, mTRE_Z: {results_inv[:,3].mean()}" + print(results_inv_str) + + # Compute Dice + + warped_train_A_seg = F.grid_sample( + train_A_seg.float(), + torch.cat(phis).flip([1]).permute([0, 2, 3, 4, 1]), + padding_mode="zeros", + mode="nearest", + align_corners=True, + ) + dices = [] + for i in range(warped_train_A_seg.shape[0]): + dices.append( + compute_dice( + train_B_seg[i].cpu().numpy(), + warped_train_A_seg[i].detach().cpu().numpy(), + ) + ) + print(np.array(dices).mean())
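+        # Mean Dice of the warped inspiration lung masks against the expiration
+        # masks, reported alongside the mTRE numbers above.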