From 76a40f9ca27fcdd35c40345c2ff8b5cd2b1ce88f Mon Sep 17 00:00:00 2001 From: Hastings Greer Date: Thu, 14 Dec 2023 08:04:00 -0500 Subject: [PATCH 01/15] initial commit for GradICON, ConstrICON submission --- ICON/icon_registration/__init__.py | 21 + ICON/icon_registration/config.py | 6 + ICON/icon_registration/data.py | 393 ++++++++ ICON/icon_registration/itk_wrapper.py | 235 +++++ ICON/icon_registration/losses.py | 505 +++++++++++ ICON/icon_registration/network_wrappers.py | 238 +++++ ICON/icon_registration/networks.py | 839 ++++++++++++++++++ .../pretrained_models/HCP_brain.py | 15 + .../pretrained_models/OAI_knees.py | 67 ++ .../pretrained_models/__init__.py | 3 + .../pretrained_models/lung_ct.py | 104 +++ ICON/icon_registration/registration_module.py | 114 +++ ICON/icon_registration/similarity.py | 234 +++++ ICON/icon_registration/test_utils.py | 64 ++ ICON/icon_registration/train.py | 121 +++ 15 files changed, 2959 insertions(+) create mode 100644 ICON/icon_registration/__init__.py create mode 100644 ICON/icon_registration/config.py create mode 100644 ICON/icon_registration/data.py create mode 100644 ICON/icon_registration/itk_wrapper.py create mode 100644 ICON/icon_registration/losses.py create mode 100644 ICON/icon_registration/network_wrappers.py create mode 100644 ICON/icon_registration/networks.py create mode 100644 ICON/icon_registration/pretrained_models/HCP_brain.py create mode 100644 ICON/icon_registration/pretrained_models/OAI_knees.py create mode 100644 ICON/icon_registration/pretrained_models/__init__.py create mode 100644 ICON/icon_registration/pretrained_models/lung_ct.py create mode 100644 ICON/icon_registration/registration_module.py create mode 100644 ICON/icon_registration/similarity.py create mode 100644 ICON/icon_registration/test_utils.py create mode 100644 ICON/icon_registration/train.py diff --git a/ICON/icon_registration/__init__.py b/ICON/icon_registration/__init__.py new file mode 100644 index 00000000..b9488792 --- /dev/null +++ b/ICON/icon_registration/__init__.py @@ -0,0 +1,21 @@ +from icon_registration.losses import ( + LNCC, + LNCCOnlyInterpolated, + BlurredSSD, + GradientICON, + InverseConsistentNet, + gaussian_blur, + ssd_only_interpolated, + ssd, + SSDOnlyInterpolated, + SSD, + NCC +) +from icon_registration.network_wrappers import ( + DownsampleRegistration, + FunctionFromMatrix, + FunctionFromVectorField, + RegistrationModule, + TwoStepRegistration, +) +from icon_registration.train import train_batchfunction, train_datasets diff --git a/ICON/icon_registration/config.py b/ICON/icon_registration/config.py new file mode 100644 index 00000000..bc7c42c0 --- /dev/null +++ b/ICON/icon_registration/config.py @@ -0,0 +1,6 @@ +import torch + +if torch.cuda.is_available(): + device = torch.device("cuda") +else: + device = torch.device("cpu") diff --git a/ICON/icon_registration/data.py b/ICON/icon_registration/data.py new file mode 100644 index 00000000..8504bf7c --- /dev/null +++ b/ICON/icon_registration/data.py @@ -0,0 +1,393 @@ +import random + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision +import tqdm + +from icon_registration import config + + +def get_dataset_mnist(split, number=5): + ds = torch.utils.data.DataLoader( + torchvision.datasets.MNIST( + "./files/", + transform=torchvision.transforms.ToTensor(), + download=True, + train=(split == "train"), + ), + batch_size=500, + ) + images = [] + for _, batch in enumerate(ds): + label = np.array(batch[1]) + batch_nines = label == number + 
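+        # keep only the digit class selected by `number` (variable names kept from the original MNIST "nines" experiment)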
images.append(np.array(batch[0])[batch_nines]) + images = np.concatenate(images) + + ds = torch.utils.data.TensorDataset(torch.Tensor(images)) + d1, d2 = ( + torch.utils.data.DataLoader( + ds, + batch_size=128, + shuffle=True, + ) + for _ in (1, 1) + ) + return d1, d2 + + +def get_dataset_1d(data_size=128, samples=6000, batch_size=128): + x = np.arange(0, 1, 1 / data_size) + x = np.reshape(x, (1, data_size)) + cx = np.random.random((samples, 1)) * 0.3 + 0.4 + r = np.random.random((samples, 1)) * 0.2 + 0.2 + + circles = np.tanh(-40 * (np.sqrt((x - cx) ** 2) - r)) + + ds = torch.utils.data.TensorDataset(torch.Tensor(np.expand_dims(circles, 1))) + d1, d2 = ( + torch.utils.data.DataLoader( + ds, + batch_size=batch_size, + shuffle=True, + ) + for _ in (1, 1) + ) + return d1, d2 + + +def get_dataset_triangles( + split=None, data_size=128, hollow=False, samples=6000, batch_size=128 +): + x, y = np.mgrid[0 : 1 : data_size * 1j, 0 : 1 : data_size * 1j] + x = np.reshape(x, (1, data_size, data_size)) + y = np.reshape(y, (1, data_size, data_size)) + cx = np.random.random((samples, 1, 1)) * 0.3 + 0.4 + cy = np.random.random((samples, 1, 1)) * 0.3 + 0.4 + r = np.random.random((samples, 1, 1)) * 0.2 + 0.2 + theta = np.random.random((samples, 1, 1)) * np.pi * 2 + isTriangle = np.random.random((samples, 1, 1)) > 0.5 + + triangles = np.sqrt((x - cx) ** 2 + (y - cy) ** 2) - r * np.cos(np.pi / 3) / np.cos( + (np.arctan2(x - cx, y - cy) + theta) % (2 * np.pi / 3) - np.pi / 3 + ) + + triangles = np.tanh(-40 * triangles) + + circles = np.tanh(-40 * (np.sqrt((x - cx) ** 2 + (y - cy) ** 2) - r)) + if hollow: + triangles = 1 - triangles**2 + circles = 1 - circles**2 + + images = isTriangle * triangles + (1 - isTriangle) * circles + + ds = torch.utils.data.TensorDataset(torch.Tensor(np.expand_dims(images, 1))) + d1, d2 = ( + torch.utils.data.DataLoader( + ds, + batch_size=batch_size, + shuffle=True, + ) + for _ in (1, 1) + ) + return d1, d2 + + +def get_dataset_retina( + extra_deformation=False, + downsample_factor=4, + blur_sigma=None, + warps_per_pair=20, + fixed_vertical_offset=None, + include_boundary=False, +): + try: + import elasticdeform + import hub + except: + + raise Exception( + """the retina dataset requires the dependencies hub and elasticdeform. 
+ Try pip install hub elasticdeform""" + ) + + ds_name = f"retina{extra_deformation}{downsample_factor}{blur_sigma}{warps_per_pair}{fixed_vertical_offset}{include_boundary}.trch" + + import os + + if os.path.exists(ds_name): + augmented_ds1_tensor, augmented_ds2_tensor = torch.load(ds_name) + else: + + res = [] + for batch in hub.load("hub://activeloop/drive-train").pytorch( + num_workers=0, batch_size=4, shuffle=False + ): + if include_boundary: + res.append(batch["manual_masks/mask"] ^ batch["masks/mask"]) + else: + res.append(batch["manual_masks/mask"]) + res = torch.cat(res) + ds_tensor = res[:, None, :, :, 0] * -1.0 + (not include_boundary) + + if fixed_vertical_offset is not None: + ds2_tensor = torch.cat( + [torch.zeros(20, 1, fixed_vertical_offset, 565), ds_tensor], axis=2 + ) + ds1_tensor = torch.cat( + [ds_tensor, torch.zeros(20, 1, fixed_vertical_offset, 565)], axis=2 + ) + else: + ds2_tensor = ds_tensor + ds1_tensor = ds_tensor + + warped_tensors = [] + print("warping images to generate dataset") + for _ in tqdm.tqdm(range(warps_per_pair)): + ds_2_list = [] + for el in ds2_tensor: + case = el[0] + # TODO implement random warping on gpu + case_warped = np.array(case) + if extra_deformation: + case_warped = elasticdeform.deform_random_grid( + case_warped, sigma=60, points=3 + ) + case_warped = elasticdeform.deform_random_grid( + case_warped, sigma=25, points=3 + ) + + case_warped = elasticdeform.deform_random_grid( + case_warped, sigma=12, points=6 + ) + ds_2_list.append(torch.tensor(case_warped)[None, None, :, :]) + ds_2_tensor = torch.cat(ds_2_list) + warped_tensors.append(ds_2_tensor) + + augmented_ds2_tensor = torch.cat(warped_tensors) + augmented_ds1_tensor = torch.cat([ds1_tensor for _ in range(warps_per_pair)]) + + torch.save((augmented_ds1_tensor, augmented_ds2_tensor), ds_name) + + batch_size = 10 + import torchvision.transforms.functional as Fv + + if blur_sigma is None: + ds1 = torch.utils.data.TensorDataset( + F.avg_pool2d(augmented_ds1_tensor, downsample_factor) + ) + else: + ds1 = torch.utils.data.TensorDataset( + Fv.gaussian_blur( + F.avg_pool2d(augmented_ds1_tensor, downsample_factor), + 4 * blur_sigma + 1, + blur_sigma, + ) + ) + d1 = torch.utils.data.DataLoader( + ds1, + batch_size=batch_size, + shuffle=False, + ) + if blur_sigma is None: + ds2 = torch.utils.data.TensorDataset( + F.avg_pool2d(augmented_ds2_tensor, downsample_factor) + ) + else: + ds2 = torch.utils.data.TensorDataset( + Fv.gaussian_blur( + F.avg_pool2d(augmented_ds2_tensor, downsample_factor), + 4 * blur_sigma + 1, + blur_sigma, + ) + ) + + d2 = torch.utils.data.DataLoader( + ds2, + batch_size=batch_size, + shuffle=False, + ) + + return d1, d2 + + +def get_dataset_sunnyside(split, scale=1): + import pickle + + with open("/playpen/tgreer/sunnyside.pickle", "rb") as f: + array = pickle.load(f) + if split == "train": + array = array[1000:] + elif split == "test": + array = array[:1000] + else: + raise ArgumentError() + + array = array[:, :, :, 0] + array = np.expand_dims(array, 1) + array = array * scale + array1 = array[::2] + array2 = array[1::2] + array12 = np.concatenate([array2, array1]) + array21 = np.concatenate([array1, array2]) + ds = torch.utils.data.TensorDataset(torch.Tensor(array21), torch.Tensor(array12)) + ds = torch.utils.data.DataLoader( + ds, + batch_size=128, + shuffle=True, + ) + return ds + + +def get_cartilage_dataset(): + cartilage = torch.load("/playpen/tgreer/cartilage_uint8s.trch") + return cartilage + + +def get_knees_dataset(): + brains = 
torch.load("/playpen/tgreer/kneestorch") + # with open("/playpen/tgreer/cartilage_eval_oriented", "rb") as f: + # cartilage = pickle.load(f) + + medbrains = [] + for b in brains: + medbrains.append(F.avg_pool3d(b, 4)) + + return brains, medbrains + +def get_copdgene_dataset(data_folder, cache_folder="./data_cache", lung_only=True, downscale=2): + ''' + This function load the preprocessed COPDGene train set. + ''' + import os + def process(iA, downscale, clamp=[-1000, 0], isSeg=False): + iA = iA[None, None, :, :, :] + #SI flip + iA = torch.flip(iA, dims=(2,)) + if isSeg: + iA = iA.float() + iA = torch.nn.functional.max_pool3d(iA, downscale) + iA[iA>0] = 1 + else: + iA = torch.clip(iA, clamp[0], clamp[1]) + clamp[0] + #TODO: For compatibility to the processed dataset(ranges between -1 to 0) used in paper, we subtract -1 here. + # Should remove -1 later. + iA = iA / torch.max(iA) - 1. + iA = torch.nn.functional.avg_pool3d(iA, downscale) + return iA + + cache_name = f"{cache_folder}/lungs_train_{downscale}xdown_scaled" + if os.path.exists(cache_name): + imgs = torch.load(cache_name, map_location='cpu') + if lung_only: + try: + masks = torch.load(f"{cache_folder}/lungs_seg_train_{downscale}xdown_scaled", map_location='cpu') + except FileNotFoundError: + print("Segmentation data not found.") + + else: + import itk + import glob + with open(f"{data_folder}/splits/train.txt") as f: + pair_paths = f.readlines() + imgs = [] + masks = [] + for name in tqdm.tqdm(list(iter(pair_paths))[:]): + name = name[:-1] # remove newline + + image_insp = torch.tensor(np.asarray(itk.imread(glob.glob(f"{data_folder} /{name}/{name}_INSP_STD*_COPD_img.nii.gz")[0]))) + image_exp= torch.tensor(np.asarray(itk.imread(glob.glob(f"{data_folder} /{name}/{name}_EXP_STD*_COPD_img.nii.gz")[0]))) + imgs.append((process(image_insp), process(image_exp))) + + seg_insp = torch.tensor(np.asarray(itk.imread(glob.glob(f"{data_folder} /{name}/{name}_INSP_STD*_COPD_label.nii.gz")[0]))) + seg_exp= torch.tensor(np.asarray(itk.imread(glob.glob(f"{data_folder} /{name}/{name}_EXP_STD*_COPD_label.nii.gz")[0]))) + masks.append((process(seg_insp, True), process(seg_exp, True))) + + torch.save(imgs, f"{cache_folder}/lungs_train_{downscale}xdown_scaled") + torch.save(masks, f"{cache_folder}/lungs_seg_train_{downscale}xdown_scaled") + + if lung_only: + imgs = torch.cat([(torch.cat(d, 1)+1)*torch.cat(m, 1) for d,m in zip(imgs, masks)], dim=0) + else: + imgs = torch.cat([torch.cat(d, 1)+1 for d in imgs], dim=0) + return torch.utils.data.TensorDataset(imgs) + +def get_learn2reg_AbdomenCTCT_dataset(data_folder, cache_folder="./data_cache", clamp=[-1000,0], downscale=1): + ''' + This function will return the training dataset of AbdomenCTCT registration task in learn2reg. 
+ ''' + + # Check whether we have cached the dataset + import os + + cache_name = f"{cache_folder}/learn2reg_abdomenctct_train_set_clamp{clamp}scale{downscale}" + if os.path.exists(cache_name): + imgs = torch.load(cache_name) + else: + import json + import itk + import glob + with open(f"{data_folder}/AbdomenCTCT_dataset.json", 'r') as data_info: + data_info = json.loads(data_info.read()) + train_cases = [c["image"].split('/')[-1].split('.')[0] for c in data_info["training"]] + imgs = [np.asarray(itk.imread(glob.glob(data_folder + "/imagesTr/" + i + ".nii.gz")[0])) for i in train_cases] + + imgs = torch.Tensor(np.expand_dims(np.array(imgs), axis=1)).float() + imgs = (torch.clamp(imgs, clamp[0], clamp[1]) - clamp[0])/(clamp[1] - clamp[0]) + + # Cache the data + if not os.path.exists(cache_folder): + os.makedirs(cache_folder) + torch.save(imgs, cache_name) + + # Scale down the image + if downscale > 1: + imgs = F.avg_pool3d(imgs, downscale) + return torch.utils.data.TensorDataset(imgs) + +def get_learn2reg_lungCT_dataset(data_folder, cache_folder="./data_cache", lung_only=True, clamp=[-1000,0], downscale=1): + ''' + This function will return the training dataset of LungCT registration task in learn2reg. + ''' + import os + + cache_name = f"{cache_folder}/learn2reg_lung_train_set_lung_only" if lung_only else f"{cache_folder}/learn2reg_lung_train_set" + cache_name += f"_clamp{clamp}scale{downscale}" + if os.path.exists(cache_name): + imgs = torch.load(cache_name) + else: + import json + import itk + import glob + with open(f"{data_folder}/NLST_dataset.json", 'r') as data_info: + data_info = json.loads(data_info.read()) + train_pairs = [[p['fixed'].split('/')[-1], p['moving'].split('/')[-1]] for p in data_info["training_paired_images"]] + imgs = [] + for p in train_pairs: + img = np.array([np.asarray(itk.imread(glob.glob(data_folder + "/imagesTr/" + i)[0])) for i in p]) + if lung_only: + mask = np.array([np.asarray(itk.imread(glob.glob(data_folder + "/" + "/masksTr/" + i)[0])) for i in p]) + img = img * mask + clamp[0] * (1 - mask) + imgs.append(img) + + imgs = torch.Tensor(np.array(imgs)).float() + imgs = (torch.clamp(imgs, clamp[0], clamp[1]) - clamp[0])/(clamp[1] - clamp[0]) + + # Cache the data + if not os.path.exists(cache_folder): + os.makedirs(cache_folder) + torch.save(imgs, cache_name) + + # Scale down the image + if downscale > 1: + imgs = F.avg_pool3d(imgs, downscale) + return torch.utils.data.TensorDataset(imgs) + + +def make_batch(data, BATCH_SIZE, SCALE): + image = torch.cat([random.choice(data) for _ in range(BATCH_SIZE)]) + image = image.reshape(BATCH_SIZE, 1, SCALE * 40, SCALE * 96, SCALE * 96) + image = image.to(config.device) + return image diff --git a/ICON/icon_registration/itk_wrapper.py b/ICON/icon_registration/itk_wrapper.py new file mode 100644 index 00000000..422d91e0 --- /dev/null +++ b/ICON/icon_registration/itk_wrapper.py @@ -0,0 +1,235 @@ +import copy + +import itk +import numpy as np +import torch +import torch.nn.functional as F + +from icon_registration import config +from icon_registration.losses import to_floats + + +def finetune_execute(model, image_A, image_B, steps): + state_dict = copy.deepcopy(model.state_dict()) + optimizer = torch.optim.Adam(model.parameters(), lr=0.00002) + for _ in range(steps): + optimizer.zero_grad() + loss_tuple = model(image_A, image_B) + print(loss_tuple) + loss_tuple[0].backward() + optimizer.step() + with torch.no_grad(): + loss = model(image_A, image_B) + model.load_state_dict(state_dict) + return loss + + +def register_pair( 
+ model, image_A, image_B, finetune_steps=None, return_artifacts=False +) -> "(itk.CompositeTransform, itk.CompositeTransform)": + + assert isinstance(image_A, itk.Image) + assert isinstance(image_B, itk.Image) + + # send model to cpu or gpu depending on config- auto detects capability + model.to(config.device) + + A_npy = np.array(image_A) + B_npy = np.array(image_B) + + assert(np.max(A_npy) != np.min(A_npy)) + assert(np.max(B_npy) != np.min(B_npy)) + # turn images into torch Tensors: add feature and batch dimensions (each of length 1) + A_trch = torch.Tensor(A_npy).to(config.device)[None, None] + B_trch = torch.Tensor(B_npy).to(config.device)[None, None] + + shape = model.identity_map.shape + + # Here we resize the input images to the shape expected by the neural network. This affects the + # pixel stride as well as the magnitude of the displacement vectors of the resulting + # displacement field, which create_itk_transform will have to compensate for. + A_resized = F.interpolate( + A_trch, size=shape[2:], mode="trilinear", align_corners=False + ) + B_resized = F.interpolate( + B_trch, size=shape[2:], mode="trilinear", align_corners=False + ) + if finetune_steps == 0: + raise Exception("To indicate no finetune_steps, pass finetune_steps=None") + + if finetune_steps == None: + with torch.no_grad(): + loss = model(A_resized, B_resized) + else: + loss = finetune_execute(model, A_resized, B_resized, finetune_steps) + + # phi_AB and phi_BA are [1, 3, H, W, D] pytorch tensors representing the forward and backward + # maps computed by the model + if hasattr(model, "prepare_for_viz"): + with torch.no_grad(): + model.prepare_for_viz(A_resized, B_resized) + phi_AB = model.phi_AB(model.identity_map) + phi_BA = model.phi_BA(model.identity_map) + + # the parameters ident, image_A, and image_B are used for their metadata + itk_transforms = ( + create_itk_transform(phi_AB, model.identity_map, image_A, image_B), + create_itk_transform(phi_BA, model.identity_map, image_B, image_A), + ) + if not return_artifacts: + return itk_transforms + else: + return itk_transforms + (to_floats(loss),) + +def register_pair_with_multimodalities( + model, image_A: list, image_B: list, finetune_steps=None, return_artifacts=False +) -> "(itk.CompositeTransform, itk.CompositeTransform)": + + assert len(image_A) == len(image_B), "image_A and image_B should have the same number of modalities." + + # send model to cpu or gpu depending on config- auto detects capability + model.to(config.device) + + A_npy, B_npy = [], [] + for image_a, image_b in zip(image_A, image_B): + assert isinstance(image_a, itk.Image) + assert isinstance(image_b, itk.Image) + + A_npy.append(np.array(image_a)) + B_npy.append(np.array(image_b)) + + assert(np.max(A_npy[-1]) != np.min(A_npy[-1])) + assert(np.max(B_npy[-1]) != np.min(B_npy[-1])) + + # turn images into torch Tensors: add batch dimensions (each of length 1) + A_trch = torch.Tensor(np.array(A_npy)).to(config.device)[None] + B_trch = torch.Tensor(np.array(B_npy)).to(config.device)[None] + + shape = model.identity_map.shape[2:] + if list(A_trch.shape[2:]) != list(shape) or (list(B_trch.shape[2:]) != list(shape)): + # Here we resize the input images to the shape expected by the neural network. This affects the + # pixel stride as well as the magnitude of the displacement vectors of the resulting + # displacement field, which create_itk_transform will have to compensate for. 
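+        # Note: mode="trilinear" below means this resize path, as written, assumes 3-D volumes.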
+ A_trch = F.interpolate( + A_trch, size=shape, mode="trilinear", align_corners=False + ) + B_trch = F.interpolate( + B_trch, size=shape, mode="trilinear", align_corners=False + ) + + if finetune_steps == 0: + raise Exception("To indicate no finetune_steps, pass finetune_steps=None") + + if finetune_steps == None: + with torch.no_grad(): + loss = model(A_trch, B_trch) + else: + loss = finetune_execute(model, A_trch, B_trch, finetune_steps) + + # phi_AB and phi_BA are [1, 3, H, W, D] pytorch tensors representing the forward and backward + # maps computed by the model + if hasattr(model, "prepare_for_viz"): + with torch.no_grad(): + model.prepare_for_viz(A_trch, B_trch) + phi_AB = model.phi_AB(model.identity_map) + phi_BA = model.phi_BA(model.identity_map) + + # the parameters ident, image_A, and image_B are used for their metadata + itk_transforms = ( + create_itk_transform(phi_AB, model.identity_map, image_A[0], image_B[0]), + create_itk_transform(phi_BA, model.identity_map, image_B[0], image_A[0]), + ) + if not return_artifacts: + return itk_transforms + else: + return itk_transforms + (to_floats(loss),) + + +def create_itk_transform(phi, ident, image_A, image_B) -> "itk.CompositeTransform": + + # itk.DeformationFieldTransform expects a displacement field, so we subtract off the identity map. + disp = (phi - ident)[0].cpu() + + network_shape_list = list(ident.shape[2:]) + + dimension = len(network_shape_list) + + tr = itk.DisplacementFieldTransform[(itk.D, dimension)].New() + + # We convert the displacement field into an itk Vector Image. + scale = torch.Tensor(network_shape_list) + + for _ in network_shape_list: + scale = scale[:, None] + disp *= scale - 1 + + # disp is a shape [3, H, W, D] tensor with vector components in the order [vi, vj, vk] + disp_itk_format = ( + disp.double() + .numpy()[list(reversed(range(dimension)))] + .transpose(list(range(1, dimension + 1)) + [0]) + ) + # disp_itk_format is a shape [H, W, D, 3] array with vector components in the order [vk, vj, vi] + # as expected by itk. 
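+    # The displacement vectors are expressed in index space of the network's sampling grid;
+    # physical spacing and orientation are handled by the resampling transforms composed below.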
+ + itk_disp_field = itk.image_from_array(disp_itk_format, is_vector=True) + + tr.SetDisplacementField(itk_disp_field) + + to_network_space = resampling_transform(image_A, list(reversed(network_shape_list))) + + from_network_space = resampling_transform( + image_B, list(reversed(network_shape_list)) + ).GetInverseTransform() + + phi_AB_itk = itk.CompositeTransform[itk.D, dimension].New() + + phi_AB_itk.PrependTransform(from_network_space) + phi_AB_itk.PrependTransform(tr) + phi_AB_itk.PrependTransform(to_network_space) + + # warp(image_A, phi_AB_itk) is close to image_B + + return phi_AB_itk + + +def resampling_transform(image, shape): + + imageType = itk.template(image)[0][itk.template(image)[1]] + + dummy_image = itk.image_from_array( + np.zeros(tuple(reversed(shape)), dtype=itk.array_from_image(image).dtype) + ) + if len(shape) == 2: + transformType = itk.MatrixOffsetTransformBase[itk.D, 2, 2] + else: + transformType = itk.VersorRigid3DTransform[itk.D] + initType = itk.CenteredTransformInitializer[transformType, imageType, imageType] + initializer = initType.New() + initializer.SetFixedImage(dummy_image) + initializer.SetMovingImage(image) + transform = transformType.New() + + initializer.SetTransform(transform) + initializer.InitializeTransform() + + if len(shape) == 3: + transformType = itk.CenteredAffineTransform[itk.D, 3] + t2 = transformType.New() + t2.SetCenter(transform.GetCenter()) + t2.SetOffset(transform.GetOffset()) + transform = t2 + m = transform.GetMatrix() + m_a = itk.array_from_matrix(m) + + input_shape = image.GetLargestPossibleRegion().GetSize() + + for i in range(len(shape)): + + m_a[i, i] = image.GetSpacing()[i] * (input_shape[i] / shape[i]) + + m_a = itk.array_from_matrix(image.GetDirection()) @ m_a + + transform.SetMatrix(itk.matrix_from_array(m_a)) + + return transform diff --git a/ICON/icon_registration/losses.py b/ICON/icon_registration/losses.py new file mode 100644 index 00000000..0cc0a2bb --- /dev/null +++ b/ICON/icon_registration/losses.py @@ -0,0 +1,505 @@ +from collections import namedtuple + +import matplotlib +import torch +import torch.nn.functional as F + +from icon_registration import config, network_wrappers + +from .mermaidlite import compute_warped_image_multiNC + + +def to_floats(stats): + out = [] + for v in stats: + if isinstance(v, torch.Tensor): + v = torch.mean(v).cpu().item() + out.append(v) + return ICONLoss(*out) + +class ICON(network_wrappers.RegistrationModule): + def __init__(self, network, similarity, lmbda): + + super().__init__() + + self.regis_net = network + self.lmbda = lmbda + self.similarity = similarity + + def __call__(self, image_A, image_B) -> ICONLoss: + return super().__call__(image_A, image_B) + + def forward(self, image_A, image_B): + + assert self.identity_map.shape[2:] == image_A.shape[2:] + assert self.identity_map.shape[2:] == image_B.shape[2:] + + # Tag used elsewhere for optimization. 
+ # Must be set at beginning of forward b/c not preserved by .cuda() etc + self.identity_map.isIdentity = True + + self.phi_AB = self.regis_net(image_A, image_B) + self.phi_BA = self.regis_net(image_B, image_A) + + self.phi_AB_vectorfield = self.phi_AB(self.identity_map) + self.phi_BA_vectorfield = self.phi_BA(self.identity_map) + + if getattr(self.similarity, "isInterpolated", False): + # tag images during warping so that the similarity measure + # can use information about whether a sample is interpolated + # or extrapolated + inbounds_tag = torch.zeros([image_A.shape[0]] + [1] + list(image_A.shape[2:]), device=image_A.device) + if len(self.input_shape) - 2 == 3: + inbounds_tag[:, :, 1:-1, 1:-1, 1:-1] = 1.0 + elif len(self.input_shape) - 2 == 2: + inbounds_tag[:, :, 1:-1, 1:-1] = 1.0 + else: + inbounds_tag[:, :, 1:-1] = 1.0 + else: + inbounds_tag = None + + self.warped_image_A = compute_warped_image_multiNC( + torch.cat([image_A, inbounds_tag], axis=1) if inbounds_tag is not None else image_A, + self.phi_AB_vectorfield, + self.spacing, + 1, + ) + self.warped_image_B = compute_warped_image_multiNC( + torch.cat([image_B, inbounds_tag], axis=1) if inbounds_tag is not None else image_B, + self.phi_BA_vectorfield, + self.spacing, + 1, + ) + + similarity_loss = self.similarity( + self.warped_image_A, image_B + ) + self.similarity(self.warped_image_B, image_A) + + Iepsilon = ( + self.identity_map + + torch.randn(*self.identity_map.shape).to(image_A.device) + * 1 + / self.identity_map.shape[-1] + ) + + # inverse consistency one way + + approximate_Iepsilon1 = self.phi_AB(self.phi_BA(Iepsilon)) + + approximate_Iepsilon2 = self.phi_BA(self.phi_AB(Iepsilon)) + + inverse_consistency_loss = torch.mean( + (Iepsilon - approximate_Iepsilon1) ** 2 + ) + torch.mean((Iepsilon - approximate_Iepsilon2) ** 2) + + transform_magnitude = torch.mean( + (self.identity_map - self.phi_AB_vectorfield) ** 2 + ) + + all_loss = self.lmbda * inverse_consistency_loss + similarity_loss + + return ICONLoss( + all_loss, + inverse_consistency_loss, + similarity_loss, + transform_magnitude, + flips(self.phi_BA_vectorfield), + ) + + +class GradICON(network_wrappers.RegistrationModule): + def compute_gradient_icon_loss(self, phi_AB, phi_BA): + Iepsilon = ( + self.identity_map + + torch.randn(*self.identity_map.shape).to(self.identity_map.device) + * 1 + / self.identity_map.shape[-1] + ) + + # compute squared Frobenius of Jacobian of icon error + + direction_losses = [] + + approximate_Iepsilon = phi_AB(phi_BA(Iepsilon)) + + inverse_consistency_error = Iepsilon - approximate_Iepsilon + + delta = 0.001 + + if len(self.identity_map.shape) == 4: + dx = torch.Tensor([[[[delta]], [[0.0]]]]).to(self.identity_map.device) + dy = torch.Tensor([[[[0.0]], [[delta]]]]).to(self.identity_map.device) + direction_vectors = (dx, dy) + + elif len(self.identity_map.shape) == 5: + dx = torch.Tensor([[[[[delta]]], [[[0.0]]], [[[0.0]]]]]).to( + self.identity_map.device + ) + dy = torch.Tensor([[[[[0.0]]], [[[delta]]], [[[0.0]]]]]).to( + self.identity_map.device + ) + dz = torch.Tensor([[[[0.0]]], [[[0.0]]], [[[delta]]]]).to( + self.identity_map.device + ) + direction_vectors = (dx, dy, dz) + elif len(self.identity_map.shape) == 3: + dx = torch.Tensor([[[delta]]]).to(self.identity_map.device) + direction_vectors = (dx,) + + for d in direction_vectors: + approximate_Iepsilon_d = phi_AB(phi_BA(Iepsilon + d)) + inverse_consistency_error_d = Iepsilon + d - approximate_Iepsilon_d + grad_d_icon_error = ( + inverse_consistency_error - 
inverse_consistency_error_d + ) / delta + direction_losses.append(torch.mean(grad_d_icon_error**2)) + + inverse_consistency_loss = sum(direction_losses) + + return inverse_consistency_loss + + def compute_similarity_measure(self, phi_AB, phi_BA, image_A, image_B): + self.phi_AB_vectorfield = phi_AB(self.identity_map) + self.phi_BA_vectorfield = phi_BA(self.identity_map) + + if getattr(self.similarity, "isInterpolated", False): + # tag images during warping so that the similarity measure + # can use information about whether a sample is interpolated + # or extrapolated + inbounds_tag = torch.zeros([image_A.shape[0]] + [1] + list(image_A.shape[2:]), device=image_A.device) + if len(self.input_shape) - 2 == 3: + inbounds_tag[:, :, 1:-1, 1:-1, 1:-1] = 1.0 + elif len(self.input_shape) - 2 == 2: + inbounds_tag[:, :, 1:-1, 1:-1] = 1.0 + else: + inbounds_tag[:, :, 1:-1] = 1.0 + else: + inbounds_tag = None + + self.warped_image_A = self.as_function( + torch.cat([image_A, inbounds_tag], axis=1) if inbounds_tag is not None else image_A + )(self.phi_AB_vectorfield) + self.warped_image_B = self.as_function( + torch.cat([image_B, inbounds_tag], axis=1) if inbounds_tag is not None else image_B + )(self.phi_BA_vectorfield) + similarity_loss = self.similarity( + self.warped_image_A, image_B + ) + self.similarity(self.warped_image_B, image_A) + return similarity_loss + + def forward(self, image_A, image_B) -> ICONLoss: + + assert self.identity_map.shape[2:] == image_A.shape[2:] + assert self.identity_map.shape[2:] == image_B.shape[2:] + + # Tag used elsewhere for optimization. + # Must be set at beginning of forward b/c not preserved by .cuda() etc + self.identity_map.isIdentity = True + + self.phi_AB = self.regis_net(image_A, image_B) + self.phi_BA = self.regis_net(image_B, image_A) + + similarity_loss = self.compute_similarity_measure( + self.phi_AB, self.phi_BA, image_A, image_B + ) + + inverse_consistency_loss = self.compute_gradient_icon_loss( + self.phi_AB, self.phi_BA + ) + + all_loss = self.lmbda * inverse_consistency_loss + similarity_loss + + transform_magnitude = torch.mean( + (self.identity_map - self.phi_AB_vectorfield) ** 2 + ) + return ICONLoss( + all_loss, + inverse_consistency_loss, + similarity_loss, + transform_magnitude, + flips(self.phi_BA_vectorfield), + ) + + +class GradientICONSparse(network_wrappers.RegistrationModule): + def __init__(self, network, similarity, lmbda): + + super().__init__() + + self.regis_net = network + self.lmbda = lmbda + self.similarity = similarity + + def forward(self, image_A, image_B): + + assert self.identity_map.shape[2:] == image_A.shape[2:] + assert self.identity_map.shape[2:] == image_B.shape[2:] + + # Tag used elsewhere for optimization. 
+ # Must be set at beginning of forward b/c not preserved by .cuda() etc + self.identity_map.isIdentity = True + + self.phi_AB = self.regis_net(image_A, image_B) + self.phi_BA = self.regis_net(image_B, image_A) + + self.phi_AB_vectorfield = self.phi_AB(self.identity_map) + self.phi_BA_vectorfield = self.phi_BA(self.identity_map) + + # tag images during warping so that the similarity measure + # can use information about whether a sample is interpolated + # or extrapolated + + if getattr(self.similarity, "isInterpolated", False): + # tag images during warping so that the similarity measure + # can use information about whether a sample is interpolated + # or extrapolated + inbounds_tag = torch.zeros([image_A.shape[0]] + [1] + list(image_A.shape[2:]), device=image_A.device) + if len(self.input_shape) - 2 == 3: + inbounds_tag[:, :, 1:-1, 1:-1, 1:-1] = 1.0 + elif len(self.input_shape) - 2 == 2: + inbounds_tag[:, :, 1:-1, 1:-1] = 1.0 + else: + inbounds_tag[:, :, 1:-1] = 1.0 + else: + inbounds_tag = None + + self.warped_image_A = compute_warped_image_multiNC( + torch.cat([image_A, inbounds_tag], axis=1) if inbounds_tag is not None else image_A, + self.phi_AB_vectorfield, + self.spacing, + 1, + ) + self.warped_image_B = compute_warped_image_multiNC( + torch.cat([image_B, inbounds_tag], axis=1) if inbounds_tag is not None else image_B, + self.phi_BA_vectorfield, + self.spacing, + 1, + ) + + similarity_loss = self.similarity( + self.warped_image_A, image_B + ) + self.similarity(self.warped_image_B, image_A) + + if len(self.input_shape) - 2 == 3: + Iepsilon = ( + self.identity_map + + 2 * torch.randn(*self.identity_map.shape).to(config.device) + / self.identity_map.shape[-1] + )[:, :, ::2, ::2, ::2] + elif len(self.input_shape) - 2 == 2: + Iepsilon = ( + self.identity_map + + 2 * torch.randn(*self.identity_map.shape).to(config.device) + / self.identity_map.shape[-1] + )[:, :, ::2, ::2] + + # compute squared Frobenius of Jacobian of icon error + + direction_losses = [] + + approximate_Iepsilon = self.phi_AB(self.phi_BA(Iepsilon)) + + inverse_consistency_error = Iepsilon - approximate_Iepsilon + + delta = 0.001 + + if len(self.identity_map.shape) == 4: + dx = torch.Tensor([[[[delta]], [[0.0]]]]).to(config.device) + dy = torch.Tensor([[[[0.0]], [[delta]]]]).to(config.device) + direction_vectors = (dx, dy) + + elif len(self.identity_map.shape) == 5: + dx = torch.Tensor([[[[[delta]]], [[[0.0]]], [[[0.0]]]]]).to(config.device) + dy = torch.Tensor([[[[[0.0]]], [[[delta]]], [[[0.0]]]]]).to(config.device) + dz = torch.Tensor([[[[0.0]]], [[[0.0]]], [[[delta]]]]).to(config.device) + direction_vectors = (dx, dy, dz) + elif len(self.identity_map.shape) == 3: + dx = torch.Tensor([[[delta]]]).to(config.device) + direction_vectors = (dx,) + + for d in direction_vectors: + approximate_Iepsilon_d = self.phi_AB(self.phi_BA(Iepsilon + d)) + inverse_consistency_error_d = Iepsilon + d - approximate_Iepsilon_d + grad_d_icon_error = ( + inverse_consistency_error - inverse_consistency_error_d + ) / delta + direction_losses.append(torch.mean(grad_d_icon_error**2)) + + inverse_consistency_loss = sum(direction_losses) + + all_loss = self.lmbda * inverse_consistency_loss + similarity_loss + + transform_magnitude = torch.mean( + (self.identity_map - self.phi_AB_vectorfield) ** 2 + ) + return ICONLoss( + all_loss, + inverse_consistency_loss, + similarity_loss, + transform_magnitude, + flips(self.phi_BA_vectorfield), + ) + + + +class BendingEnergy(network_wrappers.RegistrationModule): + def __init__(self, network, similarity, 
lmbda): + super().__init__() + + self.regis_net = network + self.lmbda = lmbda + self.similarity = similarity + + def compute_bending_energy_loss(self, phi_AB_vectorfield): + # dxdx = [f[x+h, y] + f[x-h, y] - 2 * f[x, y]]/(h**2) + # dxdy = [f[x+h, y+h] + f[x-h, y-h] - f[x+h, y-h] - f[x-h, y+h]]/(4*h**2) + # BE_2d = |dxdx| + |dydy| + 2 * |dxdy| + # psudo code: BE_2d = [torch.mean(dxdx**2) + torch.mean(dydy**2) + 2 * torch.mean(dxdy**2)]/4.0 + # BE_3d = |dxdx| + |dydy| + |dzdz| + 2 * |dxdy| + 2 * |dydz| + 2 * |dxdz| + + if len(self.identity_map.shape) == 3: + dxdx = (phi_AB_vectorfield[:, :, 2:] + - 2*phi_AB_vectorfield[:, :, 1:-1] + + phi_AB_vectorfield[:, :, :-2]) / self.spacing[0]**2 + bending_energy = torch.mean((dxdx)**2) + + elif len(self.identity_map.shape) == 4: + dxdx = (phi_AB_vectorfield[:, :, 2:] + - 2*phi_AB_vectorfield[:, :, 1:-1] + + phi_AB_vectorfield[:, :, :-2]) / self.spacing[0]**2 + dydy = (phi_AB_vectorfield[:, :, :, 2:] + - 2*phi_AB_vectorfield[:, :, :, 1:-1] + + phi_AB_vectorfield[:, :, :, :-2]) / self.spacing[1]**2 + dxdy = (phi_AB_vectorfield[:, :, 2:, 2:] + + phi_AB_vectorfield[:, :, :-2, :-2] + - phi_AB_vectorfield[:, :, 2:, :-2] + - phi_AB_vectorfield[:, :, :-2, 2:]) / (4.0*self.spacing[0]*self.spacing[1]) + bending_energy = (torch.mean(dxdx**2) + torch.mean(dydy**2) + 2*torch.mean(dxdy**2)) / 4.0 + elif len(self.identity_map.shape) == 5: + dxdx = (phi_AB_vectorfield[:, :, 2:] + - 2*phi_AB_vectorfield[:, :, 1:-1] + + phi_AB_vectorfield[:, :, :-2]) / self.spacing[0]**2 + dydy = (phi_AB_vectorfield[:, :, :, 2:] + - 2*phi_AB_vectorfield[:, :, :, 1:-1] + + phi_AB_vectorfield[:, :, :, :-2]) / self.spacing[1]**2 + dzdz = (phi_AB_vectorfield[:, :, :, :, 2:] + - 2*phi_AB_vectorfield[:, :, :, :, 1:-1] + + phi_AB_vectorfield[:, :, :, :, :-2]) / self.spacing[2]**2 + dxdy = (phi_AB_vectorfield[:, :, 2:, 2:] + + phi_AB_vectorfield[:, :, :-2, :-2] + - phi_AB_vectorfield[:, :, 2:, :-2] + - phi_AB_vectorfield[:, :, :-2, 2:]) / (4.0*self.spacing[0]*self.spacing[1]) + dydz = (phi_AB_vectorfield[:, :, :, 2:, 2:] + + phi_AB_vectorfield[:, :, :, :-2, :-2] + - phi_AB_vectorfield[:, :, :, 2:, :-2] + - phi_AB_vectorfield[:, :, :, :-2, 2:]) / (4.0*self.spacing[1]*self.spacing[2]) + dxdz = (phi_AB_vectorfield[:, :, 2:, :, 2:] + + phi_AB_vectorfield[:, :, :-2, :, :-2] + - phi_AB_vectorfield[:, :, 2:, :, :-2] + - phi_AB_vectorfield[:, :, :-2, :, 2:]) / (4.0*self.spacing[0]*self.spacing[2]) + + bending_energy = ((dxdx**2).mean() + (dydy**2).mean() + (dzdz**2).mean() + + 2.*(dxdy**2).mean() + 2.*(dydz**2).mean() + 2.*(dxdz**2).mean()) / 9.0 + + + return bending_energy + + def compute_similarity_measure(self, phi_AB_vectorfield, image_A, image_B): + + if getattr(self.similarity, "isInterpolated", False): + # tag images during warping so that the similarity measure + # can use information about whether a sample is interpolated + # or extrapolated + inbounds_tag = torch.zeros([image_A.shape[0]] + [1] + list(image_A.shape[2:]), device=image_A.device) + if len(self.input_shape) - 2 == 3: + inbounds_tag[:, :, 1:-1, 1:-1, 1:-1] = 1.0 + elif len(self.input_shape) - 2 == 2: + inbounds_tag[:, :, 1:-1, 1:-1] = 1.0 + else: + inbounds_tag[:, :, 1:-1] = 1.0 + else: + inbounds_tag = None + + self.warped_image_A = self.as_function( + torch.cat([image_A, inbounds_tag], axis=1) if inbounds_tag is not None else image_A + )(phi_AB_vectorfield) + + similarity_loss = self.similarity( + self.warped_image_A, image_B + ) + return similarity_loss + + def forward(self, image_A, image_B) -> ICONLoss: + + assert 
self.identity_map.shape[2:] == image_A.shape[2:] + assert self.identity_map.shape[2:] == image_B.shape[2:] + + # Tag used elsewhere for optimization. + # Must be set at beginning of forward b/c not preserved by .cuda() etc + self.identity_map.isIdentity = True + + self.phi_AB = self.regis_net(image_A, image_B) + self.phi_AB_vectorfield = self.phi_AB(self.identity_map) + + similarity_loss = 2 * self.compute_similarity_measure( + self.phi_AB_vectorfield, image_A, image_B + ) + + bending_energy_loss = self.compute_bending_energy_loss( + self.phi_AB_vectorfield + ) + + all_loss = self.lmbda * bending_energy_loss + similarity_loss + + transform_magnitude = torch.mean( + (self.identity_map - self.phi_AB_vectorfield) ** 2 + ) + return BendingLoss( + all_loss, + bending_energy_loss, + similarity_loss, + transform_magnitude, + flips(self.phi_AB_vectorfield), + ) + + def prepare_for_viz(self, image_A, image_B): + self.phi_AB = self.regis_net(image_A, image_B) + self.phi_AB_vectorfield = self.phi_AB(self.identity_map) + self.phi_BA = self.regis_net(image_B, image_A) + self.phi_BA_vectorfield = self.phi_BA(self.identity_map) + + self.warped_image_A = self.as_function(image_A)(self.phi_AB_vectorfield) + self.warped_image_B = self.as_function(image_B)(self.phi_BA_vectorfield) + +class Diffusion(BendingEnergyNet): + def compute_bending_energy_loss(self, phi_AB_vectorfield): + phi_AB_vectorfield = self.identity_map - phi_AB_vectorfield + if len(self.identity_map.shape) == 3: + bending_energy = torch.mean(( + - phi_AB_vectorfield[:, :, 1:] + + phi_AB_vectorfield[:, :, 1:-1] + )**2) + + elif len(self.identity_map.shape) == 4: + bending_energy = torch.mean(( + - phi_AB_vectorfield[:, :, 1:] + + phi_AB_vectorfield[:, :, :-1] + )**2) + torch.mean(( + - phi_AB_vectorfield[:, :, :, 1:] + + phi_AB_vectorfield[:, :, :, :-1] + )**2) + elif len(self.identity_map.shape) == 5: + bending_energy = torch.mean(( + - phi_AB_vectorfield[:, :, 1:] + + phi_AB_vectorfield[:, :, :-1] + )**2) + torch.mean(( + - phi_AB_vectorfield[:, :, :, 1:] + + phi_AB_vectorfield[:, :, :, :-1] + )**2) + torch.mean(( + - phi_AB_vectorfield[:, :, :, :, 1:] + + phi_AB_vectorfield[:, :, :, :, :-1] + )**2) + + + return bending_energy * self.identity_map.shape[2] **2 + diff --git a/ICON/icon_registration/network_wrappers.py b/ICON/icon_registration/network_wrappers.py new file mode 100644 index 00000000..7ed1d2a9 --- /dev/null +++ b/ICON/icon_registration/network_wrappers.py @@ -0,0 +1,238 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class DisplacementField(RegistrationModule): + """ + Wrap an inner neural network 'net' that returns a tensor of displacements + [B x N x H x W (x D)], into a RegistrationModule that returns a function that + transforms a tensor of coordinates + """ + + def __init__(self, net): + super().__init__() + self.net = net + + def forward(self, image_A, image_B): + concatenated_images = torch.cat([image_A, image_B], axis=1) + tensor_of_displacements = self.net(concatenated_images) + displacement_field = self.as_function(tensor_of_displacements) + + def transform(coordinates): + return coordinates + displacement_field(coordinates) + + return {"phi_AB": transform} + +class VelocityField(RegistrationModule): + def __init__(self, net): + super().__init__() + self.net = net + self.n_steps = 256 + + def forward(self, image_A, image_B): + concatenated_images = torch.cat([image_A, image_B], axis=1) + velocity_field = self.net(concatenated_images) + velocityfield_delta = velocity_field 
/ self.n_steps + + for _ in range(8): + velocityfield_delta = velocityfield_delta + self.as_function( + velocityfield_delta)(velocityfield_delta + self.identity_map) + def transform(coordinate_tensor): + coordinate_tensor = coordinate_tensor + self.as_function(velocityfield_delta)(coordinate_tensor) + return coordinate_tensor + return {"phi_AB": transform, "velocity_fields": [velocity_field]} + + +def multiply_matrix_vectorfield(matrix, vectorfield): + dimension = len(vectorfield.shape) - 2 + if dimension == 2: + batch_matrix_multiply = "ijkl,imj->imkl" + else: + batch_matrix_multiply = "ijkln,imj->imkln" + return torch.einsum(batch_matrix_multiply, vectorfield, matrix) + + +class Affine(RegistrationModule): + """ + wrap an inner neural network `net` that returns an N x N+1 matrix representing + an affine transform, into a RegistrationModule that returns a function that + transforms a tensor of coordinates. + """ + + def __init__(self, net): + super().__init__() + self.net = net + + def forward(self, image_A, image_B): + matrix_phi = self.net(image_A, image_B) + + def transform(tensor_of_coordinates): + shape = list(tensor_of_coordinates.shape) + shape[1] = 1 + coordinates_homogeneous = torch.cat( + [tensor_of_coordinates, torch.ones(shape, device=tensor_of_coordinates.device)], axis=1 + ) + return multiply_matrix_vectorfield(matrix_phi, coordinates_homogeneous)[:, :-1] + + return {"phi_AB": transform} + + +class TwoStep(RegistrationModule): + """Combine two RegistrationModules. + + First netPhi is called on the input images, then image_A is warped with + the resulting field, and then netPsi is called on warped A and image_B + in order to find a residual warping. Finally, the composition of the two + transforms is returned. + """ + + def __init__(self, netPhi, netPsi): + super().__init__() + self.netPhi = netPhi + self.netPsi = netPsi + + def forward(self, image_A, image_B): + phi = self.netPhi(image_A, image_B) + psi = self.netPsi( + self.as_function(image_A)(phi(self.identity_map)), + image_B + )["phi_AB"] + result = {"phi_AB": lambda tensor_of_coordinates: phi["phi_AB"](psi["phi_AB"](tensor_of_coordinates))} + + regularization_loss = 0 + if "regularization_loss" in phi: + regularization_loss += phi["regularization_loss"] + if "regularization_loss" in psi: + regularization_loss += psi["regularization_loss"] + if "regularization_loss" in phi or "regularization_loss" in psi: + result["regularization_loss"] = regularization_loss + + velocity_fields = [] + if "velocity_fields" in phi: + velocity_fields += phi["regularization_loss"] + if "velocity_fields" in psi: + velocity_fields += psi["regularization_loss"] + if "velocity_fields" in phi or "regularization_loss" in psi: + result["velocity_fields"] = regularization_loss + return result + +class Downsample(RegistrationModule): + """ + Perform registration using the wrapped RegistrationModule `net` + at half input resolution. + """ + + def __init__(self, net, dimension): + super().__init__() + self.net = net + if dimension == 2: + self.avg_pool = F.avg_pool2d + self.interpolate_mode = "bilinear" + else: + self.avg_pool = F.avg_pool3d + self.interpolate_mode = "trilinear" + self.dimension = dimension + # This member variable is read by assign_identity_map when + # walking the network tree and assigning identity_maps + # to know that all children of this module operate at a lower + # resolution. 
+ self.downscale_factor = 2 + + def forward(self, image_A, image_B): + + image_A = self.avg_pool(image_A, 2, ceil_mode=True) + image_B = self.avg_pool(image_B, 2, ceil_mode=True) + return self.net(image_A, image_B) + + +class InverseConsistentVelocityField(RegistrationModule): + def __init__(self, net): + super().__init__() + self.net = net + self.n_steps = 7 + + def forward(self, image_A, image_B): + concatenated_images_AB = torch.cat([image_A, image_B], axis=1) + concatenated_images_BA = torch.cat([image_B, image_A], axis=1) + velocity_field = self.net(concatenated_images_AB) - self.net(concatenated_images_BA) + velocityfield_delta_ab = velocity_field / 2**self.n_steps + velocityfield_delta_ba = -velocityfield_delta_ab + + for _ in range(self.n_steps): + velocityfield_delta_ab = velocityfield_delta_ab + self.as_function( + velocityfield_delta_ab + )(velocityfield_delta_ab + self.identity_map) + + def transform_AB(coordinate_tensor): + coordinate_tensor = coordinate_tensor + self.as_function( + velocityfield_delta_ab + )(coordinate_tensor) + return coordinate_tensor + + for _ in range(self.n_steps): + velocityfield_delta_ba = velocityfield_delta_ba + self.as_function( + velocityfield_delta_ba + )(velocityfield_delta_ba + self.identity_map) + + def transform_BA(coordinate_tensor): + coordinate_tensor = coordinate_tensor + self.as_function( + velocityfield_delta_ba + )(coordinate_tensor) + return coordinate_tensor + + + return {"phi_AB": transform_AB, "phi_BA": transform_BA, "velocity_fields":[velocity_field]} + +class InverseConsistentAffine(RegistrationModule): + """ + wrap an inner neural network `net` that returns an Batch x N*N+1 tensor representing + an affine transform, into a RegistrationModule that returns a function that + transforms a tensor of coordinates. 
+ """ + + def __init__(self, net): + super().__init__() + self.net = net + + def forward(self, image_A, image_B): + concatenated_images_AB = torch.cat([image_A, image_B], axis=1) + concatenated_images_BA = torch.cat([image_B, image_A], axis=1) + matrix_phi = self.net(concatenated_images_AB) - self.net(concatenated_images_BA) + + matrix_phi = matrix_phi.reshape(image_A.shape[0], len(image_A.shape), len(image_A.shape) + 1) + + + matrix_phi_AB = torch.linalg.matrix_exp(matrix_phi) + matrix_phi_BA = torch.linalg.matrix_exp(-matrix_phi) + + def transform_AB(tensor_of_coordinates): + shape = list(tensor_of_coordinates.shape) + shape[1] = 1 + coordinates_homogeneous = torch.cat( + [ + tensor_of_coordinates, + torch.ones(shape, device=tensor_of_coordinates.device), + ], + axis=1, + ) + return imultiply_matrix_vectorfield( + matrix_phi, coordinates_homogeneous + )[:, :-1] + + def transform_BA(tensor_of_coordinates): + shape = list(tensor_of_coordinates.shape) + shape[1] = 1 + coordinates_homogeneous = torch.cat( + [ + tensor_of_coordinates, + torch.ones(shape, device=tensor_of_coordinates.device), + ], + axis=1, + ) + return multiply_matrix_vectorfield( + matrix_phi_BA, coordinates_homogeneous + )[:, :-1] + + return {"phi_AB": transform_AB, "phi_BA": transform_BA} diff --git a/ICON/icon_registration/networks.py b/ICON/icon_registration/networks.py new file mode 100644 index 00000000..d0d1faf9 --- /dev/null +++ b/ICON/icon_registration/networks.py @@ -0,0 +1,839 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from icon_registration import config + + +class ConvNet(nn.Module): + def __init__(self, dimension=2, output_dim=100): + super().__init__() + self.dimension = dimension + + if dimension == 2: + self.Conv = nn.Conv2d + self.avg_pool = F.avg_pool2d + else: + self.Conv = nn.Conv3d + self.avg_pool = F.avg_pool3d + + self.features = [2, 16, 32, 64, 128, 128, 256] + self.convs = nn.ModuleList([]) + for depth in range(len(self.features) - 1): + self.convs.append( + self.Conv( + self.features[depth], + self.features[depth + 1], + kernel_size=3, + padding=1, + ) + ) + self.dense2 = nn.Linear(256, 300) + self.dense3 = nn.Linear(300, output_dim) + + def forward(self, x, y): + x = torch.cat([x, y], 1) + for depth in range(len(self.features) - 1): + x = F.relu(x) + x = self.convs[depth](x) + x = self.avg_pool(x, 2, ceil_mode=True) + x = self.avg_pool(x, x.shape[2:], ceil_mode=True) + x = torch.reshape(x, (-1, 256)) + x = F.relu(self.dense2(x)) + x = self.dense3(x) + return x + + +class Autoencoder(nn.Module): + def __init__(self, num_layers, channels): + super().__init__() + self.num_layers = num_layers + down_channels = channels[0] + up_channels = channels[1] + self.downConvs = nn.ModuleList([]) + self.upConvs = nn.ModuleList([]) + for depth in range(self.num_layers): + self.downConvs.append( + nn.Conv2d( + down_channels[depth], + down_channels[depth + 1], + kernel_size=3, + padding=1, + stride=2, + ) + ) + self.upConvs.append( + nn.ConvTranspose2d( + up_channels[depth + 1], + up_channels[depth], + kernel_size=4, + padding=1, + stride=2, + ) + ) + self.lastConv = nn.Conv2d(16, 2, kernel_size=3, padding=1) + torch.nn.init.zeros_(self.lastConv.weight) + + def forward(self, x, y): + x = torch.cat([x, y], 1) + skips = [] + for depth in range(self.num_layers): + skips.append(x) + x = F.relu(self.downConvs[depth](x)) + for depth in reversed(range(self.num_layers)): + x = F.relu(self.upConvs[depth](x)) + x = x[:, :, : skips[depth].size()[2], : skips[depth].size()[3]] 
+ x = self.lastConv(x) + return x / 10 + + +def tallAE(): + return Autoencoder( + 5, + np.array( + [ + [2, 16, 32, 64, 256, 512], + [16, 32, 64, 128, 256, 512], + ] + ), + ) + + +class Residual(nn.Module): + def __init__(self, features): + super().__init__() + self.bn1 = nn.BatchNorm2d(num_features=features) + self.bn2 = nn.BatchNorm2d(num_features=features) + + self.conv1 = nn.Conv2d(features, features, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(features, features, kernel_size=3, padding=1) + + def forward(self, x): + y = F.relu(self.bn1(x)) + y = self.conv1(y) + y = F.relu(self.bn2(y)) + y = self.conv2(y) + return y + x + + +class UNet(nn.Module): + def __init__(self, num_layers, channels, dimension): + super().__init__() + + if dimension == 2: + self.BatchNorm = nn.BatchNorm2d + self.Conv = nn.Conv2d + self.ConvTranspose = nn.ConvTranspose2d + else: + self.BatchNorm = nn.BatchNorm3d + self.Conv = nn.Conv3d + self.ConvTranspose = nn.ConvTranspose3d + self.num_layers = num_layers + down_channels = np.array(channels[0]) + up_channels_out = np.array(channels[1]) + up_channels_in = down_channels[1:] + np.concatenate([up_channels_out[1:], [0]]) + self.downConvs = nn.ModuleList([]) + self.upConvs = nn.ModuleList([]) + # self.residues = nn.ModuleList([]) + self.batchNorms = nn.ModuleList( + [ + self.BatchNorm(num_features=up_channels_out[_]) + for _ in range(self.num_layers) + ] + ) + for depth in range(self.num_layers): + self.downConvs.append( + self.Conv( + down_channels[depth], + down_channels[depth + 1], + kernel_size=3, + padding=1, + stride=2, + ) + ) + self.upConvs.append( + self.ConvTranspose( + up_channels_in[depth], + up_channels_out[depth], + kernel_size=4, + padding=1, + stride=2, + ) + ) + # self.residues.append( + # Residual(up_channels_out[depth]) + # ) + self.lastConv = self.Conv(18, dimension, kernel_size=3, padding=1) + torch.nn.init.zeros_(self.lastConv.weight) + + def forward(self, x, y): + x = torch.cat([x, y], 1) + skips = [] + for depth in range(self.num_layers): + skips.append(x) + x = F.relu(self.downConvs[depth](x)) + for depth in reversed(range(self.num_layers)): + x = F.relu(self.upConvs[depth](x)) + x = self.batchNorms[depth](x) + + x = x[:, :, : skips[depth].size()[2], : skips[depth].size()[3]] + x = torch.cat([x, skips[depth]], 1) + x = self.lastConv(x) + return x / 10 + + +def pad_or_crop(x, shape, dimension): + y = x[:, : shape[1]] + if x.size()[1] < shape[1]: + if dimension == 3: + y = F.pad(y, (0, 0, 0, 0, 0, 0, shape[1] - x.size()[1], 0)) + else: + y = F.pad(y, (0, 0, 0, 0, shape[1] - x.size()[1], 0)) + assert y.size()[1] == shape[1] + + return y + + +class UNet2(nn.Module): + def __init__(self, num_layers, channels, dimension): + super().__init__() + self.dimension = dimension + if dimension == 2: + self.BatchNorm = nn.BatchNorm2d + self.Conv = nn.Conv2d + self.ConvTranspose = nn.ConvTranspose2d + self.avg_pool = F.avg_pool2d + self.interpolate_mode = "bilinear" + else: + self.BatchNorm = nn.BatchNorm3d + self.Conv = nn.Conv3d + self.ConvTranspose = nn.ConvTranspose3d + self.avg_pool = F.avg_pool3d + self.interpolate_mode = "trilinear" + self.num_layers = num_layers + down_channels = np.array(channels[0]) + up_channels_out = np.array(channels[1]) + up_channels_in = down_channels[1:] + np.concatenate([up_channels_out[1:], [0]]) + self.downConvs = nn.ModuleList([]) + self.upConvs = nn.ModuleList([]) + # self.residues = nn.ModuleList([]) + self.batchNorms = nn.ModuleList( + [ + self.BatchNorm(num_features=up_channels_out[_]) + for _ in 
range(self.num_layers) + ] + ) + for depth in range(self.num_layers): + self.downConvs.append( + self.Conv( + down_channels[depth], + down_channels[depth + 1], + kernel_size=3, + padding=1, + stride=2, + ) + ) + self.upConvs.append( + self.ConvTranspose( + up_channels_in[depth], + up_channels_out[depth], + kernel_size=4, + padding=1, + stride=2, + ) + ) + # self.residues.append( + # Residual(up_channels_out[depth]) + # ) + self.lastConv = self.Conv( + down_channels[0] + up_channels_out[0], dimension, kernel_size=3, padding=1 + ) + torch.nn.init.zeros_(self.lastConv.weight) + torch.nn.init.zeros_(self.lastConv.bias) + + def forward(self, x, y): + x = torch.cat([x, y], 1) + skips = [] + for depth in range(self.num_layers): + skips.append(x) + y = self.downConvs[depth](F.leaky_relu(x)) + x = y + pad_or_crop( + self.avg_pool(x, 2, ceil_mode=True), y.size(), self.dimension + ) + + for depth in reversed(range(self.num_layers)): + y = self.upConvs[depth](F.leaky_relu(x)) + x = y + F.interpolate( + pad_or_crop(x, y.size(), self.dimension), + scale_factor=2, + mode=self.interpolate_mode, + align_corners=False, + ) + # x = self.residues[depth](x) + x = self.batchNorms[depth](x) + if self.dimension == 2: + x = x[:, :, : skips[depth].size()[2], : skips[depth].size()[3]] + else: + x = x[ + :, + :, + : skips[depth].size()[2], + : skips[depth].size()[3], + : skips[depth].size()[4], + ] + x = torch.cat([x, skips[depth]], 1) + x = self.lastConv(x) + return x / 10 + + +class UNet2ChunkyMiddle(nn.Module): + def __init__(self, num_layers, channels, dimension): + super().__init__() + self.dimension = dimension + if dimension == 2: + self.BatchNorm = nn.BatchNorm2d + self.Conv = nn.Conv2d + self.ConvTranspose = nn.ConvTranspose2d + self.avg_pool = F.avg_pool2d + self.interpolate_mode = "bilinear" + else: + self.BatchNorm = nn.BatchNorm3d + self.Conv = nn.Conv3d + self.ConvTranspose = nn.ConvTranspose3d + self.avg_pool = F.avg_pool3d + self.interpolate_mode = "trilinear" + self.num_layers = num_layers + down_channels = np.array(channels[0]) + up_channels_out = np.array(channels[1]) + up_channels_in = down_channels[1:] + np.concatenate([up_channels_out[1:], [0]]) + self.downConvs = nn.ModuleList([]) + self.upConvs = nn.ModuleList([]) + # self.residues = nn.ModuleList([]) + self.batchNorms = nn.ModuleList( + [ + self.BatchNorm(num_features=up_channels_out[_]) + for _ in range(self.num_layers) + ] + ) + for depth in range(self.num_layers): + self.downConvs.append( + self.Conv( + down_channels[depth], + down_channels[depth + 1], + kernel_size=3, + padding=1, + stride=2, + ) + ) + self.upConvs.append( + self.ConvTranspose( + up_channels_in[depth], + up_channels_out[depth], + kernel_size=4, + padding=1, + stride=2, + ) + ) + # self.residues.append( + # Residual(up_channels_out[depth]) + # ) + self.lastConv = self.Conv(18, dimension, kernel_size=3, padding=1) + torch.nn.init.zeros_(self.lastConv.weight) + + self.middle_dense = nn.ModuleList( + [ + torch.nn.Linear(512 * 2 * 3 * 3, 128 * 2 * 3 * 3), + torch.nn.Linear(128 * 2 * 3 * 3, 512 * 2 * 3 * 3), + ] + ) + + def forward(self, x, y): + x = torch.cat([x, y], 1) + skips = [] + for depth in range(self.num_layers): + skips.append(x) + y = self.downConvs[depth](F.leaky_relu(x)) + x = y + pad_or_crop( + self.avg_pool(x, 2, ceil_mode=True), y.size(), self.dimension + ) + y = F.layer_norm + + x = torch.reshape(x, (-1, 512 * 2 * 3 * 3)) + x = self.middle_dense[1](F.leaky_relu(self.middle_dense[0](x))) + x = torch.reshape(x, (-1, 512, 2, 3, 3)) + for depth in 
reversed(range(self.num_layers)): + y = self.upConvs[depth](F.leaky_relu(x)) + x = y + F.interpolate( + pad_or_crop(x, y.size(), self.dimension), + scale_factor=2, + mode=self.interpolate_mode, + align_corners=False, + ) + # x = self.residues[depth](x) + x = self.batchNorms[depth](x) + + x = x[:, :, : skips[depth].size()[2], : skips[depth].size()[3]] + x = torch.cat([x, skips[depth]], 1) + x = self.lastConv(x) + return x / 10 + + +class UNet3(nn.Module): + def __init__(self, num_layers, channels, dimension, normalization): + super().__init__() + + self.dimension = dimension + if dimension == 2: + self.BatchNorm = nn.BatchNorm2d + self.Conv = nn.Conv2d + self.ConvTranspose = nn.ConvTranspose2d + self.avg_pool = F.avg_pool2d + self.interpolate_mode = "bilinear" + else: + self.BatchNorm = nn.BatchNorm3d + self.Conv = nn.Conv3d + self.ConvTranspose = nn.ConvTranspose3d + self.avg_pool = F.avg_pool3d + self.interpolate_mode = "trilinear" + self.num_layers = num_layers + down_channels = np.array(channels[0]) + up_channels_out = np.array(channels[1]) + up_channels_in = down_channels[1:] + np.concatenate([up_channels_out[1:], [0]]) + self.downConvs = nn.ModuleList([]) + self.upConvs = nn.ModuleList([]) + + # More traditional residual structure + # self.down_1x1s = nn.ModuleList([]) + # self.up_1x1s = nn.ModuleList([]) + + # self.residues = nn.ModuleList([]) + self.normalization = normalization + if self.normalization == "batchnorm": + self.batchNorms = nn.ModuleList( + [ + self.BatchNorm(num_features=up_channels_out[_]) + for _ in range(self.num_layers) + ] + ) + if self.normalization == "groupnorm": + self.groupNorms = nn.ModuleList( + [ + nn.GroupNorm( + max(16, up_channels_out[depth]), up_channels_out[depth] + ) + for depth in range(self.num_layers) + ] + ) + for depth in range(self.num_layers): + self.downConvs.append( + self.Conv( + down_channels[depth], + down_channels[depth + 1], + kernel_size=3, + padding=1, + stride=2, + ) + ) + # self.down_1x1s.append( + # self.Conv( + # down_channels[depth + 1], + # down_channels[depth + 1], + # kernel_size=3, + # padding=1, + # stride=1, + # ) + # ) + self.upConvs.append( + self.ConvTranspose( + up_channels_in[depth], + up_channels_out[depth], + kernel_size=4, + padding=1, + stride=2, + ) + ) + # self.up_1x1s.append( + # self.Conv( + # up_channels_out[depth], + # up_channels_out[depth], + # kernel_size=3, + # padding=1, + # stride=1, + # ) + # ) + + # self.residues.append( + # Residual(up_channels_out[depth]) + # ) + self.lastConv = self.Conv(18, dimension, kernel_size=3, padding=1) + torch.nn.init.zeros_(self.lastConv.weight) + + def forward(self, x, y): + x = torch.cat([x, y], 1) + skips = [] + for depth in range(self.num_layers): + skips.append(x) + y = self.downConvs[depth](F.leaky_relu(x)) + # y = self.down_1x1s[depth](F.leaky_relu(y)) + x = y + pad_or_crop( + self.avg_pool(x, 2, ceil_mode=True), y.size(), self.dimension + ) + y = F.layer_norm + + for depth in reversed(range(self.num_layers)): + y = self.upConvs[depth](F.leaky_relu(x)) + # y = self.up_1x1s[depth](F.leaky_relu(y)) + x = y + F.interpolate( + pad_or_crop(x, y.size(), self.dimension), + scale_factor=2, + mode=self.interpolate_mode, + align_corners=False, + ) + # x = self.residues[depth](x) + if self.normalization == "batchnorm": + x = self.batchNorms[depth](x) + + if self.normalization == "groupnorm": + x = self.groupNorms[depth](x) + x = x[:, :, : skips[depth].size()[2], : skips[depth].size()[3]] + x = torch.cat([x, skips[depth]], 1) + x = self.lastConv(x) + return x / 10 + + +def 
tallUNet(unet=UNet, dimension=2): + return unet( + 5, + [[2, 16, 32, 64, 256, 512], [16, 32, 64, 128, 256]], + dimension, + ) + + +def tallishUNet2(dimension=2): + return UNet2( + 6, + [[2, 16, 32, 64, 256, 512, 512], [16, 32, 64, 128, 256, 512]], + dimension, + ) + + +def tallerUNet2(dimension=2): + return UNet2( + 7, + [[2, 16, 32, 64, 256, 512, 512, 512], [16, 32, 64, 128, 256, 512, 512]], + dimension, + ) + + +def tallUNet2(dimension=2, input_channels=1): + return UNet2( + 5, + [[input_channels*2, 16, 32, 64, 256, 512], [16, 32, 64, 128, 256]], + dimension, + ) + + +def tallUNet3(normalization="batchnorm", dimension=2): + return UNet3( + 5, + [[2, 16, 32, 64, 256, 512], [16, 32, 64, 128, 256]], + dimension, + normalization=normalization, + ) + + +class RegisNet(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(2, 10, kernel_size=5, padding=2) + self.conv2 = nn.Conv2d(12, 10, kernel_size=5, padding=2) + self.conv3 = nn.Conv2d(22, 10, kernel_size=5, padding=2) + self.conv4 = nn.Conv2d(32, 10, kernel_size=5, padding=2) + self.conv5 = nn.Conv2d(42, 10, kernel_size=5, padding=2) + self.conv6 = nn.Conv2d(52, 2, kernel_size=5, padding=2) + + def forward(self, x, y): + x = torch.cat([x, y], 1) + + x = torch.cat([x, F.relu(self.conv1(x))], 1) + x = torch.cat([x, F.relu(self.conv2(x))], 1) + x = torch.cat([x, F.relu(self.conv3(x))], 1) + x = torch.cat([x, F.relu(self.conv4(x))], 1) + x = torch.cat([x, F.relu(self.conv5(x))], 1) + + return self.conv6(x) + + +class FCNet1D(nn.Module): + def __init__(self, size=28): + super().__init__() + self.size = size + self.dense1 = nn.Linear(size * 2, 8000) + self.dense2 = nn.Linear(8000, 3000) + self.dense3 = nn.Linear(3000, size) + torch.nn.init.zeros_(self.dense3.weight) + + def forward(self, x, y): + x = torch.reshape(torch.cat([x, y], 1), (-1, 2 * self.size)) + x = F.relu(self.dense1(x)) + x = F.relu(self.dense2(x)) + x = self.dense3(x) + x = torch.reshape(x, (-1, 1, self.size)) + return x + + +class FCNet(nn.Module): + def __init__(self, size=28): + super().__init__() + self.size = size + self.dense1 = nn.Linear(size * size * 2, 8000) + self.dense2 = nn.Linear(8000, 3000) + self.dense3 = nn.Linear(3000, size * size * 2) + torch.nn.init.zeros_(self.dense3.weight) + + def forward(self, x, y): + x = torch.reshape(torch.cat([x, y], 1), (-1, 2 * self.size * self.size)) + x = F.relu(self.dense1(x)) + x = F.relu(self.dense2(x)) + x = self.dense3(x) + x = torch.reshape(x, (-1, 2, self.size, self.size)) + return x + + +class FCNet3D(nn.Module): + def __init__(self, shape, bottleneck=128): + super().__init__() + self.shape = shape.copy() + self.shape[1] = 3 + self.bottleneck = bottleneck + self.dense1 = nn.Linear(2 * np.product(self.shape[2:]), self.bottleneck) + self.dense2 = nn.Linear(self.bottleneck, 8000) + self.dense3 = nn.Linear(8000, self.bottleneck) + self.dense4 = nn.Linear(self.bottleneck, np.product(self.shape[1:])) + torch.nn.init.zeros_(self.dense4.weight) + torch.nn.init.zeros_(self.dense4.bias) + + def forward(self, x, y): + x = torch.reshape(torch.cat([x, y], 1), (-1, 2 * np.product(self.shape[2:]))) + x = F.relu(self.dense1(x)) + x = F.relu(self.dense2(x)) + x = F.relu(self.dense3(x)) + x = self.dense4(x) + x = torch.reshape(x, tuple(self.shape)) + return x + + +class DenseMatrixNet(nn.Module): + def __init__(self, size=28, dimension=2): + super().__init__() + self.dimension = dimension + self.size = size + self.dense1 = nn.Linear(size * size * 2, 800) + self.dense2 = nn.Linear(800, 300) + self.dense3 = 
nn.Linear(300, 6 if self.dimension == 2 else 12) + torch.nn.init.zeros_(self.dense3.weight) + + def forward(self, x, y): + x = torch.reshape(torch.cat([x, y], 1), (-1, 2 * self.size * self.size)) + x = F.relu(self.dense1(x)) + x = F.relu(self.dense2(x)) + x = self.dense3(x) + if self.dimension == 3: + x = torch.reshape(x, (-1, 3, 4)) + x = torch.cat( + [ + x, + torch.Tensor([[[0, 0, 0, 1]]]) + .to(x.device) + .expand(x.shape[0], -1, -1), + ], + 1, + ) + x = x + torch.Tensor( + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]] + ).to(x.device) + elif self.dimension == 2: + x = torch.reshape(x, (-1, 2, 3)) + x = torch.cat( + [ + x, + torch.Tensor([[[0, 0, 1]]]).to(x.device).expand(x.shape[0], -1, -1), + ], + 1, + ) + x = x + torch.Tensor([[1, 0, 0], [0, 1, 0], [0, 0, 0]]).to(x.device) + else: + raise ArgumentError() + return x + + +class ConvolutionalMatrixNet(nn.Module): + def __init__(self, dimension=2): + super().__init__() + self.dimension = dimension + + if dimension == 2: + self.Conv = nn.Conv2d + self.avg_pool = F.avg_pool2d + else: + self.Conv = nn.Conv3d + self.avg_pool = F.avg_pool3d + + self.features = [2, 16, 32, 64, 128, 256, 512] + self.convs = nn.ModuleList([]) + for depth in range(len(self.features) - 1): + self.convs.append( + self.Conv( + self.features[depth], + self.features[depth + 1], + kernel_size=3, + padding=1, + ) + ) + self.dense2 = nn.Linear(512, 300) + self.dense3 = nn.Linear(300, 6 if self.dimension == 2 else 12) + torch.nn.init.zeros_(self.dense3.weight) + torch.nn.init.zeros_(self.dense3.bias) + + def forward(self, x, y): + x = torch.cat([x, y], 1) + for depth in range(len(self.features) - 1): + x = F.relu(x) + x = self.convs[depth](x) + x = self.avg_pool(x, 2, ceil_mode=True) + x = self.avg_pool(x, x.shape[2:], ceil_mode=True) + x = torch.reshape(x, (-1, 512)) + x = F.relu(self.dense2(x)) + x = self.dense3(x) + if self.dimension == 3: + x = torch.reshape(x, (-1, 3, 4)) + x = torch.cat( + [ + x, + torch.Tensor([[[0, 0, 0, 1]]]) + .to(x.device) + .expand(x.shape[0], -1, -1), + ], + 1, + ) + x = x + torch.Tensor( + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]] + ).to(x.device) + x = torch.matmul( + torch.Tensor( + [[1, 0, 0, 0.5], [0, 1, 0, 0.5], [0, 0, 1, 0.5], [0, 0, 0, 1]] + ).to(x.device), + x, + ) + x = torch.matmul( + x, + torch.Tensor( + [[1, 0, 0, -0.5], [0, 1, 0, -0.5], [0, 0, 1, -0.5], [0, 0, 0, 1]] + ).to(x.device), + ) + elif self.dimension == 2: + x = torch.reshape(x, (-1, 2, 3)) + x = torch.cat( + [ + x, + torch.Tensor([[[0, 0, 1]]]).to(x.device).expand(x.shape[0], -1, -1), + ], + 1, + ) + x = x + torch.Tensor([[1, 0, 0], [0, 1, 0], [0, 0, 0]]).to(x.device) + x = torch.matmul( + torch.Tensor([[1, 0, 0.5], [0, 1, 0.5], [0, 0, 1]]).to(x.device), x + ) + x = torch.matmul( + x, + torch.Tensor([[1, 0, -0.5], [0, 1, -0.5], [0, 0, 1]]).to(x.device), + ) + else: + raise ArgumentError() + return x + + +class StumpyConvolutionalMatrixNet(nn.Module): + def __init__(self, dimension=2): + super().__init__() + self.dimension = dimension + + if dimension == 2: + self.Conv = nn.Conv2d + self.avg_pool = F.avg_pool2d + else: + self.Conv = nn.Conv3d + self.avg_pool = F.avg_pool3d + + self.features = [2, 16, 32, 64, 128, 256] + self.convs = nn.ModuleList([]) + for depth in range(len(self.features) - 1): + self.convs.append( + self.Conv( + self.features[depth], + self.features[depth + 1], + kernel_size=3, + padding=1, + ) + ) + self.dense2 = nn.Linear(256 * 2 * 3 * 3, 3000) + self.dense3 = nn.Linear(3000, 6 if self.dimension == 2 else 12) + 
torch.nn.init.zeros_(self.dense3.weight) + torch.nn.init.zeros_(self.dense3.bias) + + def forward(self, x, y): + x = torch.cat([x, y], 1) + for depth in range(len(self.features) - 1): + x = F.relu(x) + x = self.convs[depth](x) + x = self.avg_pool(x, 2, ceil_mode=True) + x = torch.reshape(x, (-1, 256 * 2 * 3 * 3)) + x = F.relu(self.dense2(x)) + x = self.dense3(x) + if self.dimension == 3: + x = torch.reshape(x, (-1, 3, 4)) + x = torch.cat( + [ + x, + torch.Tensor([[[0, 0, 0, 1]]]) + .to(x.device) + .expand(x.shape[0], -1, -1), + ], + 1, + ) + x = x + torch.Tensor( + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]] + ).to(x.device) + x = torch.matmul( + torch.Tensor( + [[1, 0, 0, 0.5], [0, 1, 0, 0.5], [0, 0, 1, 0.5], [0, 0, 0, 1]] + ).to(x.device), + x, + ) + x = torch.matmul( + x, + torch.Tensor( + [[1, 0, 0, -0.5], [0, 1, 0, -0.5], [0, 0, 1, -0.5], [0, 0, 0, 1]] + ).to(x.device), + ) + elif self.dimension == 2: + x = torch.reshape(x, (-1, 2, 3)) + x = torch.cat( + [ + x, + torch.Tensor([[[0, 0, 1]]]).to(x.device).expand(x.shape[0], -1, -1), + ], + 1, + ) + x = x + torch.Tensor([[1, 0, 0], [0, 1, 0], [0, 0, 0]]).to(x.device) + x = torch.matmul( + torch.Tensor([[1, 0, 0.5], [0, 1, 0.5], [0, 0, 1]]).to(x.device), x + ) + x = torch.matmul( + x, + torch.Tensor([[1, 0, -0.5], [0, 1, -0.5], [0, 0, 1]]).to(x.device), + ) + else: + raise ArgumentError() + return x diff --git a/ICON/icon_registration/pretrained_models/HCP_brain.py b/ICON/icon_registration/pretrained_models/HCP_brain.py new file mode 100644 index 00000000..4d5fcff6 --- /dev/null +++ b/ICON/icon_registration/pretrained_models/HCP_brain.py @@ -0,0 +1,15 @@ +import itk +from .lung_ct import init_network + +def brain_network_preprocess(image: "itk.Image") -> "itk.Image": + if type(image) == itk.Image[itk.SS, 3] : + cast_filter = itk.CastImageFilter[itk.Image[itk.SS, 3], itk.Image[itk.F, 3]].New() + cast_filter.SetInput(image) + cast_filter.Update() + image = cast_filter.GetOutput() + _, max_ = itk.image_intensity_min_max(image) + image = itk.shift_scale_image_filter(image, shift=0., scale = .9 / max_) + return image + +def brain_registration_model(pretrained=True): + return init_network("brain", pretrained=pretrained) diff --git a/ICON/icon_registration/pretrained_models/OAI_knees.py b/ICON/icon_registration/pretrained_models/OAI_knees.py new file mode 100644 index 00000000..83cbab3c --- /dev/null +++ b/ICON/icon_registration/pretrained_models/OAI_knees.py @@ -0,0 +1,67 @@ +from os.path import exists + +import torch + +import icon_registration +from icon_registration import config + +from .. import networks +from .lung_ct import init_network +from ..losses import SSD + + +def OAI_knees_registration_model(pretrained=True): + # The definition of our final 4 step registration network. 
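+    # Concretely (see the composition below): two displacement-field steps (phi, psi)
+    # are composed by TwoStepRegistration into a low-resolution network, that network
+    # is wrapped in DownsampleRegistration and composed with a third step at full
+    # resolution, and a fourth step is stacked on top. The whole stack is wrapped in
+    # InverseConsistentNet with an SSD similarity term and a weight of 3600.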
+ + phi = icon_registration.FunctionFromVectorField( + networks.tallUNet(unet=networks.UNet2ChunkyMiddle, dimension=3) + ) + psi = icon_registration.FunctionFromVectorField(networks.tallUNet2(dimension=3)) + + pretrained_lowres_net = icon_registration.TwoStepRegistration(phi, psi) + + hires_net = icon_registration.TwoStepRegistration( + icon_registration.DownsampleRegistration(pretrained_lowres_net, dimension=3), + icon_registration.FunctionFromVectorField(networks.tallUNet2(dimension=3)), + ) + + fourth_net = icon_registration.InverseConsistentNet( + icon_registration.TwoStepRegistration( + hires_net, + icon_registration.FunctionFromVectorField(networks.tallUNet2(dimension=3)), + ), + SSD(), + 3600, + ) + + BATCH_SIZE = 3 + SCALE = 2 # 1 IS QUARTER RES, 2 IS HALF RES, 4 IS FULL RES + input_shape = [BATCH_SIZE, 1, 40 * SCALE, 96 * SCALE, 96 * SCALE] + + if pretrained: + weights_location = "network_weights/pretrained_OAI_model_compact" + if not exists(weights_location): + print("Downloading pretrained model (500 Mb)") + import urllib.request + import os + + os.makedirs("network_weights", exist_ok=True) + + urllib.request.urlretrieve( + "https://github.com/uncbiag/ICON/releases/download/pretrained_oai_model/OAI_knees_ICON_model.pth", + weights_location, + ) + + trained_weights = torch.load(weights_location, map_location=torch.device("cpu")) + fourth_net.load_state_dict(trained_weights) + + fourth_net.assign_identity_map(input_shape) + + net = fourth_net + net.to(config.device) + net.eval() + return net + + +def OAI_knees_gradICON_model(pretrained=True): + return init_network("knee", pretrained) diff --git a/ICON/icon_registration/pretrained_models/__init__.py b/ICON/icon_registration/pretrained_models/__init__.py new file mode 100644 index 00000000..48d9313b --- /dev/null +++ b/ICON/icon_registration/pretrained_models/__init__.py @@ -0,0 +1,3 @@ +from .OAI_knees import OAI_knees_gradICON_model, OAI_knees_registration_model +from .lung_ct import LungCT_registration_model, lung_network_preprocess +from .HCP_brain import brain_registration_model, brain_network_preprocess diff --git a/ICON/icon_registration/pretrained_models/lung_ct.py b/ICON/icon_registration/pretrained_models/lung_ct.py new file mode 100644 index 00000000..8586daaa --- /dev/null +++ b/ICON/icon_registration/pretrained_models/lung_ct.py @@ -0,0 +1,104 @@ +import itk +import torch + +import icon_registration.config as config + +from .. import losses, network_wrappers, networks + + +def make_network(): + dimension = 3 + inner_net = network_wrappers.FunctionFromVectorField( + networks.tallUNet2(dimension=dimension)) + + for _ in range(2): + inner_net = network_wrappers.TwoStepRegistration( + network_wrappers.DownsampleRegistration(inner_net, + dimension=dimension), + network_wrappers.FunctionFromVectorField( + networks.tallUNet2(dimension=dimension))) + inner_net = network_wrappers.TwoStepRegistration( + inner_net, + network_wrappers.FunctionFromVectorField( + networks.tallUNet2(dimension=dimension))) + + net = losses.GradientICONSparse(inner_net, + similarity=losses.LNCC(sigma=5), + lmbda=1.5) + + return net + + +def init_network(task, pretrained=True): + if task == "lung": + input_shape = [1, 1, 175, 175, 175] + elif task == "knee": + input_shape = [1, 1, 80, 192, 192] + elif task == "brain": + input_shape = [1, 1, 130, 155, 130] + else: + print(f"Task {task} is not defined. 
Fall back to the lung model.") + task = "lung" + input_shape = [1, 1, 175, 175, 175] + + net = make_network() + net.assign_identity_map(input_shape) + + if pretrained: + from os.path import exists + weights_location = f"network_weights/{task}_model" + + if not exists(f"{weights_location}/{task}_model_weights.trch"): + print("Downloading pretrained model") + import urllib.request + import os + download_path = "https://github.com/uncbiag/ICON/releases/download" + download_path = f"{download_path}/pretrained_models_v1.0.0" + + os.makedirs(weights_location, exist_ok=True) + urllib.request.urlretrieve( + f"{download_path}/{task}_model_weights_step_2.trch", + f"{weights_location}/{task}_model_weights.trch", + ) + + trained_weights = torch.load( + f"{weights_location}/{task}_model_weights.trch", + map_location=torch.device("cpu"), + ) + net.regis_net.load_state_dict(trained_weights, strict=False) + net.assign_identity_map(input_shape) + + net.to(config.device) + net.eval() + return net + + +def lung_network_preprocess(image: "itk.Image", + segmentation: "itk.Image") -> "itk.Image": + + image = itk.clamp_image_filter(image, Bounds=(-1000, 0)) + cast_filter = itk.CastImageFilter[type(image), itk.Image.F3].New() + cast_filter.SetInput(image) + cast_filter.Update() + image = cast_filter.GetOutput() + + segmentation_cast_filter = itk.CastImageFilter[type(segmentation), + itk.Image.F3].New() + segmentation_cast_filter.SetInput(segmentation) + segmentation_cast_filter.Update() + segmentation = segmentation_cast_filter.GetOutput() + + image = itk.shift_scale_image_filter(image, shift=1000, scale=1 / 1000) + + mask_filter = itk.MultiplyImageFilter[itk.Image.F3, itk.Image.F3, + itk.Image.F3].New() + + mask_filter.SetInput1(image) + mask_filter.SetInput2(segmentation) + mask_filter.Update() + + return mask_filter.GetOutput() + + +def LungCT_registration_model(pretrained=True): + return init_network("lung", pretrained=pretrained) diff --git a/ICON/icon_registration/registration_module.py b/ICON/icon_registration/registration_module.py new file mode 100644 index 00000000..b2096f35 --- /dev/null +++ b/ICON/icon_registration/registration_module.py @@ -0,0 +1,114 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +class RegistrationModule(nn.Module): + r"""Base class for icon modules that perform registration. + + A subclass of RegistrationModule should have a forward method that + takes as input two images image_A and image_B, and returns a python function + phi_AB that transforms a tensor of coordinates. + + RegistrationModule provides a method as_function that turns a tensor + representing an image into a python function mapping a tensor of coordinates + into a tensor of intensities :math:`\mathbb{R}^N \rightarrow \mathbb{R}` . + Mathematically, this is what an image is anyway. + + After this class is constructed, but before it is used, you _must_ call + assign_identity_map on it or on one of its parents to define the coordinate + system associated with input images. + + The contract that a successful registration fulfils is: + for a tensor of coordinates X, self.as_function(image_A)(phi_AB(X)) ~= self.as_function(image_B)(X) + + ie + + .. math:: + I^A \circ \Phi^{AB} \simeq I^B + + In particular, self.as_function(image_A)(phi_AB(self.identity_map)) ~= image_B + """ + + def __init__(self): + super().__init__() + self.downscale_factor = 1 + + def as_function(self, image): + """image is a tensor with shape self.input_shape. 
+ Returns a python function that maps a tensor of coordinates [batch x N_dimensions x ...] + into a tensor of intensities. + """ + + return lambda coordinates: compute_warped_image_multiNC( + image, coordinates, self.spacing, 1 + ) + + def assign_identity_map(self, input_shape, parents_identity_map=None): + self.input_shape = np.array(input_shape) + self.input_shape[0] = 1 + self.spacing = 1.0 / (self.input_shape[2::] - 1) + + # if parents_identity_map is not None: + # self.identity_map = parents_identity_map + # else: + _id = identity_map_multiN(self.input_shape, self.spacing) + self.register_buffer("identity_map", torch.from_numpy(_id), persistent=False) + + if self.downscale_factor != 1: + child_shape = np.concatenate( + [ + self.input_shape[:2], + np.ceil(self.input_shape[2:] / self.downscale_factor).astype(int), + ] + ) + else: + child_shape = self.input_shape + for child in self.children(): + if isinstance(child, RegistrationModule): + child.assign_identity_map( + child_shape, + # None if self.downscale_factor != 1 else self.identity_map, + ) + def make_ddf_from_icon_transform(self, transform): + """Compute A deformation field compatible with monai's Warp + using an ICON transform. The assosciated ICON identity_map is also required + """ + field_0_1 = transform(identity_map) - self.identity_map + network_shape_list = list(self.identity_map.shape[2:]) + scale = torch.Tensor(network_shape_list).to(self.identity_map.device) + + for _ in network_shape_list: + scale = scale[:, None] + scale = scale[None, :] + field_spacing_1 = scale * field_0_1 + return field_spacing_1 + def make_ddf_using_icon_module(self, image_A, image_B): + """Compute a deformation field compatible with monai's Warp + using an ICON RegistrationModule. If the RegistrationModule returns a transform, this function + returns the monai version of that transform. If the RegistrationModule returns a loss, + this function returns a monai version of the internal transform as well as the loss. + """ + + res = self(image_A, image_B) + field = self.make_ddf_from_icon_transform(res["phi_AB"] + ) + return field, res + def forward(image_A, image_B): + """Register a pair of images: + return a python function phi_AB that warps a tensor of coordinates such that + + .. code-block:: python + + self.as_function(image_A)(phi_AB(self.identity_map)) ~= image_B + + .. math:: + I^A \circ \Phi^{AB} \simeq I^B + + :param image_A: the moving image + :param image_B: the fixed image + :return: :math:`\Phi^{AB}` + """ + raise NotImplementedError() + + diff --git a/ICON/icon_registration/similarity.py b/ICON/icon_registration/similarity.py new file mode 100644 index 00000000..60a78153 --- /dev/null +++ b/ICON/icon_registration/similarity.py @@ -0,0 +1,234 @@ + +def normalize(image): + dimension = len(image.shape) - 2 + if dimension == 2: + dim_reduce = [2, 3] + elif dimension == 3: + dim_reduce = [2, 3, 4] + image_centered = image - torch.mean(image, dim_reduce, keepdim=True) + stddev = torch.sqrt(torch.mean(image_centered**2, dim_reduce, keepdim=True)) + return image_centered / stddev + + +class SimilarityBase: + def __init__(self, isInterpolated=False): + self.isInterpolated = isInterpolated + +class NCC(SimilarityBase): + def __init__(self): + super().__init__(isInterpolated=False) + + def __call__(self, image_A, image_B): + assert image_A.shape == image_B.shape, "The shape of image_A and image_B sould be the same." 
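+        # Global normalized cross-correlation: both images are normalized to zero mean
+        # and unit variance over their spatial dimensions, the pointwise product is
+        # averaged, and the result is returned as 1 - NCC so a perfect match gives 0.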
+        A = normalize(image_A)
+        B = normalize(image_B)
+        res = torch.mean(A * B)
+        return 1 - res
+
+# torchvision removed this function from torchvision.transforms.functional_tensor, so we are vendoring it.
+def _get_gaussian_kernel1d(kernel_size, sigma):
+    ksize_half = (kernel_size - 1) * 0.5
+    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
+    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+    kernel1d = pdf / pdf.sum()
+    return kernel1d
+
+def gaussian_blur(tensor, kernel_size, sigma, padding="same"):
+    kernel1d = _get_gaussian_kernel1d(kernel_size=kernel_size, sigma=sigma).to(
+        tensor.device, dtype=tensor.dtype
+    )
+    out = tensor
+    group = tensor.shape[1]
+
+    if len(tensor.shape) - 2 == 1:
+        out = torch.conv1d(out, kernel1d[None, None, :].expand(group,-1,-1), padding="same", groups=group)
+    elif len(tensor.shape) - 2 == 2:
+        out = torch.conv2d(out, kernel1d[None, None, :, None].expand(group,-1,-1,-1), padding="same", groups=group)
+        out = torch.conv2d(out, kernel1d[None, None, None, :].expand(group,-1,-1,-1), padding="same", groups=group)
+    elif len(tensor.shape) - 2 == 3:
+        out = torch.conv3d(out, kernel1d[None, None, :, None, None].expand(group,-1,-1,-1,-1), padding="same", groups=group)
+        out = torch.conv3d(out, kernel1d[None, None, None, :, None].expand(group,-1,-1,-1,-1), padding="same", groups=group)
+        out = torch.conv3d(out, kernel1d[None, None, None, None, :].expand(group,-1,-1,-1,-1), padding="same", groups=group)
+
+    return out
+
+
+class LNCC(SimilarityBase):
+    def __init__(self, sigma):
+        super().__init__(isInterpolated=False)
+        self.sigma = sigma
+
+    def blur(self, tensor):
+        return gaussian_blur(tensor, self.sigma * 4 + 1, self.sigma)
+
+    def __call__(self, image_A, image_B):
+        I = image_A
+        J = image_B
+        assert I.shape == J.shape, "The shape of image I and J should be the same."
+
+        return torch.mean(
+            1
+            - (self.blur(I * J) - (self.blur(I) * self.blur(J)))
+            / torch.sqrt(
+                (self.blur(I * I) - self.blur(I) ** 2 + 0.00001)
+                * (self.blur(J * J) - self.blur(J) ** 2 + 0.00001)
+            )
+        )
+
+
+class LNCCOnlyInterpolated(SimilarityBase):
+    def __init__(self, sigma):
+        super().__init__(isInterpolated=True)
+        self.sigma = sigma
+
+    def blur(self, tensor):
+        return gaussian_blur(tensor, self.sigma * 4 + 1, self.sigma)
+
+    def __call__(self, image_A, image_B):
+
+        I = image_A[:, :-1]
+        J = image_B
+
+        assert I.shape == J.shape, "The shape of image I and J should be the same."
+        lncc_everywhere = 1 - (
+            self.blur(I * J) - (self.blur(I) * self.blur(J))
+        ) / torch.sqrt(
+            (self.blur(I * I) - self.blur(I) ** 2 + 0.00001)
+            * (self.blur(J * J) - self.blur(J) ** 2 + 0.00001)
+        )
+
+        with torch.no_grad():
+            A_inbounds = image_A[:, -1:]
+
+            inbounds_mask = self.blur(A_inbounds) > 0.999
+
+        if len(image_A.shape) - 2 == 3:
+            dimensions_to_sum_over = [2, 3, 4]
+        elif len(image_A.shape) - 2 == 2:
+            dimensions_to_sum_over = [2, 3]
+        elif len(image_A.shape) - 2 == 1:
+            dimensions_to_sum_over = [2]
+
+        lncc_loss = torch.sum(
+            inbounds_mask * lncc_everywhere, dimensions_to_sum_over
+        ) / torch.sum(inbounds_mask, dimensions_to_sum_over)
+
+        return torch.mean(lncc_loss)
+
+
+class BlurredSSD(SimilarityBase):
+    def __init__(self, sigma):
+        super().__init__(isInterpolated=False)
+        self.sigma = sigma
+
+    def blur(self, tensor):
+        return gaussian_blur(tensor, self.sigma * 4 + 1, self.sigma)
+
+    def __call__(self, image_A, image_B):
+        assert image_A.shape == image_B.shape, "The shape of image_A and image_B should be the same."
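+        # Blurred SSD: mean squared difference between Gaussian-smoothed images
+        # (kernel size 4 * sigma + 1).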
+ return torch.mean((self.blur(image_A) - self.blur(image_B)) ** 2) + + +class AdaptiveNCC(SimilarityBase): + def __init__(self, level=4, threshold=0.1, gamma=1.5, sigma=2): + super().__init__(isInterpolated=False) + self.level = level + self.threshold = threshold + self.gamma = gamma + self.sigma = sigma + + def blur(self, tensor): + return gaussian_blur(tensor, self.sigma * 2 + 1, self.sigma) + + def __call__(self, image_A, image_B): + assert image_A.shape == image_B.shape, "The shape of image_A and image_B sould be the same." + def _nccBeforeMean(image_A, image_B): + A = normalize(image_A) + B = normalize(image_B) + res = torch.mean(A * B, dim=(1, 2, 3, 4)) + return 1 - res + + sims = [_nccBeforeMean(image_A, image_B)] + for i in range(self.level): + if i == 0: + sims.append(_nccBeforeMean(self.blur(image_A), self.blur(image_B))) + else: + sims.append( + _nccBeforeMean( + self.blur(F.avg_pool3d(image_A, 2**i)), + self.blur(F.avg_pool3d(image_B, 2**i)), + ) + ) + + sim_loss = sims[0] + 0 + lamb_ = 1.0 + for i in range(1, len(sims)): + lamb = torch.clamp( + sims[i].detach() / (self.threshold / (self.gamma ** (len(sims) - i))), + 0, + 1, + ) + sim_loss = lamb * sims[i] + (1 - lamb) * sim_loss + lamb_ *= 1 - lamb + + return torch.mean(sim_loss) + +class SSD(SimilarityBase): + def __init__(self): + super().__init__(isInterpolated=False) + + def __call__(self, image_A, image_B): + assert image_A.shape == image_B.shape, "The shape of image_A and image_B sould be the same." + return torch.mean((image_A - image_B) ** 2) + +class SSDOnlyInterpolated(SimilarityBase): + def __init__(self): + super().__init__(isInterpolated=True) + + def __call__(self, image_A, image_B): + if len(image_A.shape) - 2 == 3: + dimensions_to_sum_over = [2, 3, 4] + elif len(image_A.shape) - 2 == 2: + dimensions_to_sum_over = [2, 3] + elif len(image_A.shape) - 2 == 1: + dimensions_to_sum_over = [2] + + inbounds_mask = image_A[:, -1:] + image_A = image_A[:, :-1] + assert image_A.shape == image_B.shape, "The shape of image_A and image_B sould be the same." + + inbounds_squared_distance = inbounds_mask * (image_A - image_B) ** 2 + sum_squared_distance = torch.sum(inbounds_squared_distance, dimensions_to_sum_over) + divisor = torch.sum(inbounds_mask, dimensions_to_sum_over) + ssds = sum_squared_distance / divisor + return torch.mean(ssds) + + +def flips(phi, in_percentage=False): + if len(phi.size()) == 5: + a = (phi[:, :, 1:, 1:, 1:] - phi[:, :, :-1, 1:, 1:]).detach() + b = (phi[:, :, 1:, 1:, 1:] - phi[:, :, 1:, :-1, 1:]).detach() + c = (phi[:, :, 1:, 1:, 1:] - phi[:, :, 1:, 1:, :-1]).detach() + + dV = torch.sum(torch.cross(a, b, 1) * c, axis=1, keepdims=True) + if in_percentage: + return torch.mean((dV < 0).float()) * 100. + else: + return torch.sum(dV < 0) / phi.shape[0] + elif len(phi.size()) == 4: + du = (phi[:, :, 1:, :-1] - phi[:, :, :-1, :-1]).detach() + dv = (phi[:, :, :-1, 1:] - phi[:, :, :-1, :-1]).detach() + dA = du[:, 0] * dv[:, 1] - du[:, 1] * dv[:, 0] + if in_percentage: + return torch.mean((dA < 0).float()) * 100. + else: + return torch.sum(dA < 0) / phi.shape[0] + elif len(phi.size()) == 3: + du = (phi[:, :, 1:] - phi[:, :, :-1]).detach() + if in_percentage: + return torch.mean((du < 0).float()) * 100. 
+ else: + return torch.sum(du < 0) / phi.shape[0] + else: + raise ValueError() + diff --git a/ICON/icon_registration/test_utils.py b/ICON/icon_registration/test_utils.py new file mode 100644 index 00000000..758e7849 --- /dev/null +++ b/ICON/icon_registration/test_utils.py @@ -0,0 +1,64 @@ +import pathlib +import subprocess +import sys +import numpy as np + +TEST_DATA_DIR = pathlib.Path(__file__).parent.parent.parent / "test_files" + + +def download_test_data(): + subprocess.run( + [ + "girder-client", + "--api-url", + "https://data.kitware.com/api/v1", + "localsync", + "61d3a99d4acac99f429277d7", + str(TEST_DATA_DIR), + ], + #stdout=sys.stdout, + ) + + +COPD_spacing = { + "copd1": [0.625, 0.625, 2.5], + "copd2": [0.645, 0.645, 2.5], + "copd3": [0.652, 0.652, 2.5], + "copd4": [0.590, 0.590, 2.5], + "copd5": [0.647, 0.647, 2.5], + "copd6": [0.633, 0.633, 2.5], + "copd7": [0.625, 0.625, 2.5], + "copd8": [0.586, 0.586, 2.5], + "copd9": [0.664, 0.664, 2.5], + "copd10": [0.742, 0.742, 2.5], +} + + +def read_copd_pointset(f_path): + """Points are deliminated by '\n' and X,Y,Z of each point are deliminated by '\t'. + + :param f_path: the path to the file containing + the position of points from copdgene dataset. + :return: numpy array of points in physical coordinates + """ + spacing = COPD_spacing[f_path.split("/")[-1].split("_")[0]] + spacing = np.expand_dims(spacing, 0) + with open(f_path) as fp: + content = fp.read().split("\n") + + # Read number of points from second + count = len(content) - 1 + + # Read the points + points = np.ndarray([count, 3], dtype=np.float64) + for i in range(count): + if content[i] == "": + break + temp = content[i].split("\t") + points[i, 0] = float(temp[0]) + points[i, 1] = float(temp[1]) + points[i, 2] = float(temp[2]) + + # The copd gene points are in index space instead of physical space. + # Move them to physical space. + return (points - 1) * spacing diff --git a/ICON/icon_registration/train.py b/ICON/icon_registration/train.py new file mode 100644 index 00000000..873e12da --- /dev/null +++ b/ICON/icon_registration/train.py @@ -0,0 +1,121 @@ +from datetime import datetime + +import torch +import tqdm + +from .losses import ICONLoss, to_floats +import icon_registration.config + +def write_stats(writer, stats: ICONLoss, ite): + for k, v in to_floats(stats)._asdict().items(): + writer.add_scalar(k, v, ite) + + +def train_batchfunction( + net, + optimizer, + make_batch, + steps=100000, + step_callback=(lambda net: None), + unwrapped_net=None, +): + """A training function intended for long running experiments, with tensorboard logging + and model checkpoints. 
Use for medical registration training + """ + import footsteps + from torch.utils.tensorboard import SummaryWriter + + if unwrapped_net is None: + unwrapped_net = net + + loss_curve = [] + writer = SummaryWriter( + footsteps.output_dir + "/" + datetime.now().strftime("%Y%m%d-%H%M%S"), + flush_secs=30, + ) + + visualization_moving, visualization_fixed = [m[:4] for m in make_batch()] + for iteration in range(0, steps): + optimizer.zero_grad() + moving_image, fixed_image = make_batch() + loss_object = net(moving_image, fixed_image) + loss = torch.mean(loss_object.all_loss) + loss.backward() + + step_callback(unwrapped_net) + + print(to_floats(loss_object)) + write_stats(writer, loss_object, iteration) + optimizer.step() + + if iteration % 300 == 0: + torch.save( + optimizer.state_dict(), + footsteps.output_dir + "optimizer_weights_" + str(iteration), + ) + torch.save( + unwrapped_net.regis_net.state_dict(), + footsteps.output_dir + "network_weights_" + str(iteration), + ) + unwrapped_net.eval() + print("val (from train set)") + warped = [] + with torch.no_grad(): + for i in range(4): + print( unwrapped_net(visualization_moving[i:i + 1], visualization_fixed[i:i + 1])) + warped.append(unwrapped_net.warped_image_A.cpu()) + warped = torch.cat(warped) + unwrapped_net.train() + + def render(im): + if len(im.shape) == 5: + im = im[:, :, :, :, im.shape[4] // 2] + if torch.min(im) < 0: + im = im - torch.min(im) + if torch.max(im) > 1: + im = im / torch.max(im) + return im[:4, [0, 0, 0]].detach().cpu() + + writer.add_images( + "moving_image", render(visualization_moving[:4]), iteration, dataformats="NCHW" + ) + writer.add_images( + "fixed_image", render(visualization_fixed[:4]), iteration, dataformats="NCHW" + ) + writer.add_images( + "warped_moving_image", + render(warped), + iteration, + dataformats="NCHW", + ) + writer.add_images( + "difference", + render(torch.clip((warped[:4, :1] - visualization_fixed[:4, :1].cpu()) + 0.5, 0, 1)), + iteration, + dataformats="NCHW", + ) + + + + +def train_datasets(net, optimizer, d1, d2, epochs=400): + """A training function for quick experiments""" + batch_size = net.identity_map.shape[0] + loss_history = [] + for epoch in tqdm.tqdm(range(epochs)): + for A, B in list(zip(d1, d2)): + if True: # A[0].size()[0] == batch_size: + image_A = A[0].to(icon_registration.config.device) + image_B = B[0].to(icon_registration.config.device) + optimizer.zero_grad() + + loss_object = net(image_A, image_B) + + loss_object.all_loss.backward() + optimizer.step() + + loss_history.append(to_floats(loss_object)) + return loss_history + + +train2d = train_datasets From e7d90ce019108fdd0f4070a3336fb68eea84d433 Mon Sep 17 00:00:00 2001 From: Hastings Greer Date: Fri, 15 Dec 2023 13:00:37 -0500 Subject: [PATCH 02/15] move files --- ICON/{icon_registration => icon}/__init__.py | 0 ICON/{icon_registration => icon}/config.py | 0 ICON/{icon_registration => icon}/data.py | 0 ICON/{icon_registration => icon}/itk_wrapper.py | 0 ICON/{icon_registration => icon}/losses.py | 0 ICON/{icon_registration => icon}/network_wrappers.py | 0 ICON/{icon_registration => icon}/networks.py | 0 ICON/{icon_registration => icon}/pretrained_models/HCP_brain.py | 0 ICON/{icon_registration => icon}/pretrained_models/OAI_knees.py | 0 ICON/{icon_registration => icon}/pretrained_models/__init__.py | 0 ICON/{icon_registration => icon}/pretrained_models/lung_ct.py | 0 ICON/{icon_registration => icon}/registration_module.py | 0 ICON/{icon_registration => icon}/similarity.py | 0 ICON/{icon_registration => icon}/test_utils.py 
| 0 ICON/{icon_registration => icon}/train.py | 0 15 files changed, 0 insertions(+), 0 deletions(-) rename ICON/{icon_registration => icon}/__init__.py (100%) rename ICON/{icon_registration => icon}/config.py (100%) rename ICON/{icon_registration => icon}/data.py (100%) rename ICON/{icon_registration => icon}/itk_wrapper.py (100%) rename ICON/{icon_registration => icon}/losses.py (100%) rename ICON/{icon_registration => icon}/network_wrappers.py (100%) rename ICON/{icon_registration => icon}/networks.py (100%) rename ICON/{icon_registration => icon}/pretrained_models/HCP_brain.py (100%) rename ICON/{icon_registration => icon}/pretrained_models/OAI_knees.py (100%) rename ICON/{icon_registration => icon}/pretrained_models/__init__.py (100%) rename ICON/{icon_registration => icon}/pretrained_models/lung_ct.py (100%) rename ICON/{icon_registration => icon}/registration_module.py (100%) rename ICON/{icon_registration => icon}/similarity.py (100%) rename ICON/{icon_registration => icon}/test_utils.py (100%) rename ICON/{icon_registration => icon}/train.py (100%) diff --git a/ICON/icon_registration/__init__.py b/ICON/icon/__init__.py similarity index 100% rename from ICON/icon_registration/__init__.py rename to ICON/icon/__init__.py diff --git a/ICON/icon_registration/config.py b/ICON/icon/config.py similarity index 100% rename from ICON/icon_registration/config.py rename to ICON/icon/config.py diff --git a/ICON/icon_registration/data.py b/ICON/icon/data.py similarity index 100% rename from ICON/icon_registration/data.py rename to ICON/icon/data.py diff --git a/ICON/icon_registration/itk_wrapper.py b/ICON/icon/itk_wrapper.py similarity index 100% rename from ICON/icon_registration/itk_wrapper.py rename to ICON/icon/itk_wrapper.py diff --git a/ICON/icon_registration/losses.py b/ICON/icon/losses.py similarity index 100% rename from ICON/icon_registration/losses.py rename to ICON/icon/losses.py diff --git a/ICON/icon_registration/network_wrappers.py b/ICON/icon/network_wrappers.py similarity index 100% rename from ICON/icon_registration/network_wrappers.py rename to ICON/icon/network_wrappers.py diff --git a/ICON/icon_registration/networks.py b/ICON/icon/networks.py similarity index 100% rename from ICON/icon_registration/networks.py rename to ICON/icon/networks.py diff --git a/ICON/icon_registration/pretrained_models/HCP_brain.py b/ICON/icon/pretrained_models/HCP_brain.py similarity index 100% rename from ICON/icon_registration/pretrained_models/HCP_brain.py rename to ICON/icon/pretrained_models/HCP_brain.py diff --git a/ICON/icon_registration/pretrained_models/OAI_knees.py b/ICON/icon/pretrained_models/OAI_knees.py similarity index 100% rename from ICON/icon_registration/pretrained_models/OAI_knees.py rename to ICON/icon/pretrained_models/OAI_knees.py diff --git a/ICON/icon_registration/pretrained_models/__init__.py b/ICON/icon/pretrained_models/__init__.py similarity index 100% rename from ICON/icon_registration/pretrained_models/__init__.py rename to ICON/icon/pretrained_models/__init__.py diff --git a/ICON/icon_registration/pretrained_models/lung_ct.py b/ICON/icon/pretrained_models/lung_ct.py similarity index 100% rename from ICON/icon_registration/pretrained_models/lung_ct.py rename to ICON/icon/pretrained_models/lung_ct.py diff --git a/ICON/icon_registration/registration_module.py b/ICON/icon/registration_module.py similarity index 100% rename from ICON/icon_registration/registration_module.py rename to ICON/icon/registration_module.py diff --git a/ICON/icon_registration/similarity.py 
b/ICON/icon/similarity.py similarity index 100% rename from ICON/icon_registration/similarity.py rename to ICON/icon/similarity.py diff --git a/ICON/icon_registration/test_utils.py b/ICON/icon/test_utils.py similarity index 100% rename from ICON/icon_registration/test_utils.py rename to ICON/icon/test_utils.py diff --git a/ICON/icon_registration/train.py b/ICON/icon/train.py similarity index 100% rename from ICON/icon_registration/train.py rename to ICON/icon/train.py From 817096a6cd0b4967298218cc79a0571c5b3af216 Mon Sep 17 00:00:00 2001 From: Hastings Greer Date: Mon, 18 Dec 2023 13:51:17 -0500 Subject: [PATCH 03/15] WIP squashme --- ICON/LICENSE | 13 + ICON/{icon => icon_registration}/__init__.py | 0 ICON/{icon => icon_registration}/config.py | 0 ICON/{icon => icon_registration}/data.py | 0 .../itk_wrapper.py | 0 ICON/{icon => icon_registration}/losses.py | 0 .../network_wrappers.py | 16 +- ICON/{icon => icon_registration}/networks.py | 0 .../pretrained_models/HCP_brain.py | 0 .../pretrained_models/OAI_knees.py | 0 .../pretrained_models/__init__.py | 0 .../pretrained_models/lung_ct.py | 0 .../registration_module.py | 46 +-- .../{icon => icon_registration}/similarity.py | 0 .../{icon => icon_registration}/test_utils.py | 0 ICON/{icon => icon_registration}/train.py | 0 ICON/pyproject.toml | 6 + ICON/requirements.txt | 8 + ICON/setup.cfg | 30 ++ ICON/setup.py | 3 + ICON/test/__init__.py | 0 ICON/test/test_2d_registration_train.py | 110 +++++++ ICON/test/test_brain_itk.py | 79 +++++ ICON/test/test_imports.py | 28 ++ ICON/test/test_knee_itk.py | 126 +++++++ ICON/test/test_knee_registration.py | 183 +++++++++++ ICON/test/test_losses.py | 127 ++++++++ ICON/test/test_lung_itk.py | 126 +++++++ ICON/test/test_lung_registration.py | 308 ++++++++++++++++++ 29 files changed, 1184 insertions(+), 25 deletions(-) create mode 100644 ICON/LICENSE rename ICON/{icon => icon_registration}/__init__.py (100%) rename ICON/{icon => icon_registration}/config.py (100%) rename ICON/{icon => icon_registration}/data.py (100%) rename ICON/{icon => icon_registration}/itk_wrapper.py (100%) rename ICON/{icon => icon_registration}/losses.py (100%) rename ICON/{icon => icon_registration}/network_wrappers.py (94%) rename ICON/{icon => icon_registration}/networks.py (100%) rename ICON/{icon => icon_registration}/pretrained_models/HCP_brain.py (100%) rename ICON/{icon => icon_registration}/pretrained_models/OAI_knees.py (100%) rename ICON/{icon => icon_registration}/pretrained_models/__init__.py (100%) rename ICON/{icon => icon_registration}/pretrained_models/lung_ct.py (100%) rename ICON/{icon => icon_registration}/registration_module.py (75%) rename ICON/{icon => icon_registration}/similarity.py (100%) rename ICON/{icon => icon_registration}/test_utils.py (100%) rename ICON/{icon => icon_registration}/train.py (100%) create mode 100644 ICON/pyproject.toml create mode 100644 ICON/requirements.txt create mode 100644 ICON/setup.cfg create mode 100644 ICON/setup.py create mode 100644 ICON/test/__init__.py create mode 100644 ICON/test/test_2d_registration_train.py create mode 100644 ICON/test/test_brain_itk.py create mode 100644 ICON/test/test_imports.py create mode 100644 ICON/test/test_knee_itk.py create mode 100644 ICON/test/test_knee_registration.py create mode 100644 ICON/test/test_losses.py create mode 100644 ICON/test/test_lung_itk.py create mode 100644 ICON/test/test_lung_registration.py diff --git a/ICON/LICENSE b/ICON/LICENSE new file mode 100644 index 00000000..db8ea3b3 --- /dev/null +++ b/ICON/LICENSE @@ -0,0 +1,13 
@@ +Copyright 2017, Hastings Greer + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/ICON/icon/__init__.py b/ICON/icon_registration/__init__.py similarity index 100% rename from ICON/icon/__init__.py rename to ICON/icon_registration/__init__.py diff --git a/ICON/icon/config.py b/ICON/icon_registration/config.py similarity index 100% rename from ICON/icon/config.py rename to ICON/icon_registration/config.py diff --git a/ICON/icon/data.py b/ICON/icon_registration/data.py similarity index 100% rename from ICON/icon/data.py rename to ICON/icon_registration/data.py diff --git a/ICON/icon/itk_wrapper.py b/ICON/icon_registration/itk_wrapper.py similarity index 100% rename from ICON/icon/itk_wrapper.py rename to ICON/icon_registration/itk_wrapper.py diff --git a/ICON/icon/losses.py b/ICON/icon_registration/losses.py similarity index 100% rename from ICON/icon/losses.py rename to ICON/icon_registration/losses.py diff --git a/ICON/icon/network_wrappers.py b/ICON/icon_registration/network_wrappers.py similarity index 94% rename from ICON/icon/network_wrappers.py rename to ICON/icon_registration/network_wrappers.py index 7ed1d2a9..1ed8e40c 100644 --- a/ICON/icon/network_wrappers.py +++ b/ICON/icon_registration/network_wrappers.py @@ -144,7 +144,17 @@ def forward(self, image_A, image_B): image_A = self.avg_pool(image_A, 2, ceil_mode=True) image_B = self.avg_pool(image_B, 2, ceil_mode=True) - return self.net(image_A, image_B) + result = self.net(image_A, image_B) + + # MONAI's ddf coordinate convention depends on resolution: + + for key in ["phi_AB", "phi_BA"]: + if key in lowres_result: + highres_phi = lambda coords: 2 * result[key](coords / 2) + result[key] = highres_phi + + return result + class InverseConsistentVelocityField(RegistrationModule): @@ -213,11 +223,11 @@ def transform_AB(tensor_of_coordinates): coordinates_homogeneous = torch.cat( [ tensor_of_coordinates, - torch.ones(shape, device=tensor_of_coordinates.device), + self.torch.ones(shape, device=tensor_of_coordinates.device), ], axis=1, ) - return imultiply_matrix_vectorfield( + return multiply_matrix_vectorfield( matrix_phi, coordinates_homogeneous )[:, :-1] diff --git a/ICON/icon/networks.py b/ICON/icon_registration/networks.py similarity index 100% rename from ICON/icon/networks.py rename to ICON/icon_registration/networks.py diff --git a/ICON/icon/pretrained_models/HCP_brain.py b/ICON/icon_registration/pretrained_models/HCP_brain.py similarity index 100% rename from ICON/icon/pretrained_models/HCP_brain.py rename to ICON/icon_registration/pretrained_models/HCP_brain.py diff --git a/ICON/icon/pretrained_models/OAI_knees.py b/ICON/icon_registration/pretrained_models/OAI_knees.py similarity index 100% rename from ICON/icon/pretrained_models/OAI_knees.py rename to ICON/icon_registration/pretrained_models/OAI_knees.py diff --git a/ICON/icon/pretrained_models/__init__.py b/ICON/icon_registration/pretrained_models/__init__.py similarity index 100% rename from ICON/icon/pretrained_models/__init__.py rename to 
ICON/icon_registration/pretrained_models/__init__.py diff --git a/ICON/icon/pretrained_models/lung_ct.py b/ICON/icon_registration/pretrained_models/lung_ct.py similarity index 100% rename from ICON/icon/pretrained_models/lung_ct.py rename to ICON/icon_registration/pretrained_models/lung_ct.py diff --git a/ICON/icon/registration_module.py b/ICON/icon_registration/registration_module.py similarity index 75% rename from ICON/icon/registration_module.py rename to ICON/icon_registration/registration_module.py index b2096f35..43168e45 100644 --- a/ICON/icon/registration_module.py +++ b/ICON/icon_registration/registration_module.py @@ -3,6 +3,9 @@ import torch.nn.functional as F from torch import nn +from monai.networks.blocks import Warp + + class RegistrationModule(nn.Module): r"""Base class for icon modules that perform registration. @@ -33,27 +36,34 @@ class RegistrationModule(nn.Module): def __init__(self): super().__init__() self.downscale_factor = 1 + self.warp = Warp() + self.identity_map = None def as_function(self, image): - """image is a tensor with shape self.input_shape. + """image is a (potentially vector valued) tensor with shape self.input_shape. Returns a python function that maps a tensor of coordinates [batch x N_dimensions x ...] - into a tensor of intensities. + into a tensor of the intensity of `image` at `coordinates`. + + This allows translating the standard notation of registration papers more literally into code. + + I \\circ \\Phi , the standard mathematical notation for a warped image, has the type + "function from coordinates to intensities" and can be translated to the python code + + warped_image = lambda coords: self.as_function(I)(phi(coords)) + + Often, this should actually be left as a function. If a tensor is needed, conversion is: + + warped_image_tensor = warped_image(self.identity_map) """ - return lambda coordinates: compute_warped_image_multiNC( - image, coordinates, self.spacing, 1 + return lambda coordinates: self.warp( + image, coordinates - self.identity_map ) def assign_identity_map(self, input_shape, parents_identity_map=None): - self.input_shape = np.array(input_shape) - self.input_shape[0] = 1 - self.spacing = 1.0 / (self.input_shape[2::] - 1) - - # if parents_identity_map is not None: - # self.identity_map = parents_identity_map - # else: - _id = identity_map_multiN(self.input_shape, self.spacing) - self.register_buffer("identity_map", torch.from_numpy(_id), persistent=False) + self.input_shape = input_shape + _id = self.warp.get_reference_grid(input_shape) + self.register_buffer("identity_map", _id, persistent=False) if self.downscale_factor != 1: child_shape = np.concatenate( @@ -74,15 +84,7 @@ def make_ddf_from_icon_transform(self, transform): """Compute A deformation field compatible with monai's Warp using an ICON transform. The assosciated ICON identity_map is also required """ - field_0_1 = transform(identity_map) - self.identity_map - network_shape_list = list(self.identity_map.shape[2:]) - scale = torch.Tensor(network_shape_list).to(self.identity_map.device) - - for _ in network_shape_list: - scale = scale[:, None] - scale = scale[None, :] - field_spacing_1 = scale * field_0_1 - return field_spacing_1 + return transform(self.identity_map) - self.identity_map def make_ddf_using_icon_module(self, image_A, image_B): """Compute a deformation field compatible with monai's Warp using an ICON RegistrationModule. 
If the RegistrationModule returns a transform, this function diff --git a/ICON/icon/similarity.py b/ICON/icon_registration/similarity.py similarity index 100% rename from ICON/icon/similarity.py rename to ICON/icon_registration/similarity.py diff --git a/ICON/icon/test_utils.py b/ICON/icon_registration/test_utils.py similarity index 100% rename from ICON/icon/test_utils.py rename to ICON/icon_registration/test_utils.py diff --git a/ICON/icon/train.py b/ICON/icon_registration/train.py similarity index 100% rename from ICON/icon/train.py rename to ICON/icon_registration/train.py diff --git a/ICON/pyproject.toml b/ICON/pyproject.toml new file mode 100644 index 00000000..b3887988 --- /dev/null +++ b/ICON/pyproject.toml @@ -0,0 +1,6 @@ +[build-system] +requires = [ + "setuptools>42", + "wheel" +] +build-backend = "setuptools.build_meta" diff --git a/ICON/requirements.txt b/ICON/requirements.txt new file mode 100644 index 00000000..5d357d03 --- /dev/null +++ b/ICON/requirements.txt @@ -0,0 +1,8 @@ +torch +torchvision +tensorboard +tqdm +matplotlib +footsteps>=0.1.6 +itk==5.3.0 +girder_client==3.1.8 diff --git a/ICON/setup.cfg b/ICON/setup.cfg new file mode 100644 index 00000000..83119cf1 --- /dev/null +++ b/ICON/setup.cfg @@ -0,0 +1,30 @@ +[metadata] +name = icon +version = 1.1.2 +author = Hastings Greer +author_email = t@hgreer.com +description = A package for image registration regularized by inverse consistency +long_description = file: README.md +long_description_content_type = text/markdown +classifiers = + Programming Language :: Python :: 3 + License :: OSI Approved :: Apache Software License + Operating System :: POSIX :: Linux + +[options] +packages = icon_registration +python_requires = >=3.4 + +install_requires = + torch + torchvision + tensorboard + tqdm + matplotlib + footsteps>=0.1.6 + itk==5.3.0 + girder_client==3.1.8 + +[options.packages.find] +where = src + diff --git a/ICON/setup.py b/ICON/setup.py new file mode 100644 index 00000000..60684932 --- /dev/null +++ b/ICON/setup.py @@ -0,0 +1,3 @@ +from setuptools import setup + +setup() diff --git a/ICON/test/__init__.py b/ICON/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ICON/test/test_2d_registration_train.py b/ICON/test/test_2d_registration_train.py new file mode 100644 index 00000000..e1583afc --- /dev/null +++ b/ICON/test/test_2d_registration_train.py @@ -0,0 +1,110 @@ +import unittest + + +class Test2DRegistrationTrain(unittest.TestCase): + def test_2d_registration_train(self): + import icon_registration + + import icon_registration.data as data + import icon_registration.networks as networks + from icon_registration import SSD + + import numpy as np + import torch + import random + import os + + random.seed(1) + torch.manual_seed(1) + torch.cuda.manual_seed(1) + np.random.seed(1) + + batch_size = 128 + + d1, d2 = data.get_dataset_triangles( + data_size=50, hollow=False, batch_size=batch_size + ) + d1_t, d2_t = data.get_dataset_triangles( + data_size=50, hollow=False, batch_size=batch_size + ) + + lmbda = 2048 + + print("ICON training") + net = icon_registration.InverseConsistentNet( + icon_registration.FunctionFromVectorField(networks.tallUNet2(dimension=2)), + # Our image similarity metric. 
The last channel of x and y is whether the value is interpolated or extrapolated, + # which is used by some metrics but not this one + SSD(), + lmbda, + ) + + input_shape = list(next(iter(d1))[0].size()) + input_shape[0] = 1 + net.assign_identity_map(input_shape) + net.cuda() + optimizer = torch.optim.Adam(net.parameters(), lr=0.001) + net.train() + + y = icon_registration.train_datasets(net, optimizer, d1, d2, epochs=5) + + # Test that image similarity is good enough + self.assertLess(np.mean(np.array(y)[-5:, 1]), 0.1) + + # Test that folds are rare enough + self.assertLess(np.mean(np.exp(np.array(y)[-5:, 3] - 0.1)), 2) + for l in y: + print(l) + + def test_2d_registration_train_GradICON(self): + import icon_registration + + import icon_registration.data as data + import icon_registration.networks as networks + from icon_registration import SSD + + import numpy as np + import torch + import random + import os + + random.seed(1) + torch.manual_seed(1) + torch.cuda.manual_seed(1) + np.random.seed(1) + + batch_size = 128 + + d1, d2 = data.get_dataset_triangles( + data_size=50, hollow=False, batch_size=batch_size + ) + d1_t, d2_t = data.get_dataset_triangles( + data_size=50, hollow=False, batch_size=batch_size + ) + + lmbda = 1.0 + + print("GradientICON training") + net = icon_registration.GradientICON( + icon_registration.FunctionFromVectorField(networks.tallUNet2(dimension=2)), + # Our image similarity metric. The last channel of x and y is whether the value is interpolated or extrapolated, + # which is used by some metrics but not this one + SSD(), + lmbda, + ) + + input_shape = next(iter(d1))[0].size() + net.assign_identity_map(input_shape) + net.cuda() + optimizer = torch.optim.Adam(net.parameters(), lr=0.001) + net.train() + + y = icon_registration.train_datasets(net, optimizer, d1, d2, epochs=5) + + # Test that image similarity is good enough + self.assertLess(np.mean(np.array(y)[-5:, 1]), 0.1) + + # Test that folds are rare enough + self.assertLess(np.mean(np.exp(np.array(y)[-5:, 3] - 0.1)), 2) + for l in y: + print(l) diff --git a/ICON/test/test_brain_itk.py b/ICON/test/test_brain_itk.py new file mode 100644 index 00000000..9423aff5 --- /dev/null +++ b/ICON/test/test_brain_itk.py @@ -0,0 +1,79 @@ +import itk +import numpy as np +import unittest +import matplotlib.pyplot as plt +import numpy as np + +import icon_registration.test_utils +import icon_registration.pretrained_models +import icon_registration.itk_wrapper + + +class TestItkRegistration(unittest.TestCase): + def test_itk_registration(self): + print("brain GradICON") + import os + + os.environ["FOOTSTEPS_NAME"] = "test" + import footsteps + + icon_registration.test_utils.download_test_data() + + model = icon_registration.pretrained_models.brain_registration_model( + pretrained=True + ) + + image_A = itk.imread( + f"{icon_registration.test_utils.TEST_DATA_DIR}/brain_test_data/2_T1w_acpc_dc_restore_brain.nii.gz" + ) + + image_B = itk.imread( + f"{icon_registration.test_utils.TEST_DATA_DIR}/brain_test_data/8_T1w_acpc_dc_restore_brain.nii.gz" + ) + + image_A_processed = icon_registration.pretrained_models.brain_network_preprocess( + image_A + ) + + image_B_processed = icon_registration.pretrained_models.brain_network_preprocess( + image_B + ) + + phi_AB, phi_BA = icon_registration.itk_wrapper.register_pair( + model, image_A_processed, image_B_processed + ) + + assert isinstance(phi_AB, itk.CompositeTransform) + interpolator = itk.LinearInterpolateImageFunction.New(image_A) + + warped_image_A = itk.resample_image_filter( + 
image_A_processed, + transform=phi_AB, + interpolator=interpolator, + size=itk.size(image_B), + output_spacing=itk.spacing(image_B), + output_direction=image_B.GetDirection(), + output_origin=image_B.GetOrigin(), + ) + + plt.imshow( + np.array(itk.checker_board_image_filter(warped_image_A, image_B_processed))[40] + ) + plt.colorbar() + plt.savefig(footsteps.output_dir + "grid.png") + plt.clf() + plt.imshow(np.array(warped_image_A)[40]) + plt.savefig(footsteps.output_dir + "warped.png") + plt.clf() + + + reference = np.load(icon_registration.test_utils.TEST_DATA_DIR / "brain_test_data/2_and_8_warped_itkfix.npy") + + np.save( + footsteps.output_dir + "warped_brain.npy", + itk.array_from_image(warped_image_A)[40], + ) + + self.assertLess( + np.mean(np.abs(reference - itk.array_from_image(warped_image_A)[40])), 1e-5 + ) diff --git a/ICON/test/test_imports.py b/ICON/test/test_imports.py new file mode 100644 index 00000000..535097c4 --- /dev/null +++ b/ICON/test/test_imports.py @@ -0,0 +1,28 @@ +import unittest + + +class TestImports(unittest.TestCase): + def test_imports_CPU(self): + import icon_registration + import icon_registration.networks + import icon_registration.data + + def test_requirements_match_cfg_CPU(self): + from inspect import getsourcefile + import os.path as path, sys + import configparser + + current_dir = path.dirname(path.abspath(getsourcefile(lambda: 0))) + parent_dir = current_dir[: current_dir.rfind(path.sep)] + + with open(parent_dir + "/requirements.txt") as f: + requirements_txt = "\n" + f.read() + requirements_cfg = configparser.ConfigParser() + requirements_cfg.read(parent_dir + "/setup.cfg") + requirements_cfg = requirements_cfg["options"]["install_requires"] + "\n" + self.assertEqual(requirements_txt, requirements_cfg) + + def test_pytorch_cuda(self): + import torch + + x = torch.Tensor([7]).cuda diff --git a/ICON/test/test_knee_itk.py b/ICON/test/test_knee_itk.py new file mode 100644 index 00000000..204ce9ab --- /dev/null +++ b/ICON/test/test_knee_itk.py @@ -0,0 +1,126 @@ +import itk +import numpy as np +import unittest +import matplotlib.pyplot as plt +import numpy as np + +import icon_registration.test_utils +import icon_registration.pretrained_models +import icon_registration.itk_wrapper + + +class TestItkRegistration(unittest.TestCase): + def test_itk_registration(self): + import os + + os.environ["FOOTSTEPS_NAME"] = "test" + import footsteps + + icon_registration.test_utils.download_test_data() + + model = icon_registration.pretrained_models.OAI_knees_registration_model( + pretrained=True + ) + + image_A = itk.imread( + str( + icon_registration.test_utils.TEST_DATA_DIR + / "knees_diverse_sizes" + / + # "9126260_20060921_SAG_3D_DESS_LEFT_11309302_image.nii.gz") + "9487462_20081003_SAG_3D_DESS_RIGHT_11495603_image.nii.gz" + ) + ) + + image_B = itk.imread( + str( + icon_registration.test_utils.TEST_DATA_DIR + / "knees_diverse_sizes" + / "9225063_20090413_SAG_3D_DESS_RIGHT_12784112_image.nii.gz" + ) + ) + print(image_A.GetLargestPossibleRegion().GetSize()) + print(image_B.GetLargestPossibleRegion().GetSize()) + print(image_A.GetSpacing()) + print(image_B.GetSpacing()) + + phi_AB, phi_BA = icon_registration.itk_wrapper.register_pair( + model, image_A, image_B + ) + + assert isinstance(phi_AB, itk.CompositeTransform) + interpolator = itk.LinearInterpolateImageFunction.New(image_A) + + warped_image_A = itk.resample_image_filter( + image_A, + transform=phi_AB, + interpolator=interpolator, + size=itk.size(image_B), + output_spacing=itk.spacing(image_B), + 
output_direction=image_B.GetDirection(), + output_origin=image_B.GetOrigin(), + ) + + plt.imshow( + np.array(itk.checker_board_image_filter(warped_image_A, image_B))[40] + ) + plt.colorbar() + plt.savefig(footsteps.output_dir + "grid.png") + plt.clf() + plt.imshow(np.array(warped_image_A)[40]) + plt.savefig(footsteps.output_dir + "warped.png") + plt.clf() + + reference = np.load(icon_registration.test_utils.TEST_DATA_DIR / "warped_itkfix.npy") + + np.save( + footsteps.output_dir + "warped_knee.npy", + itk.array_from_image(warped_image_A)[40], + ) + + self.assertLess( + np.mean(np.abs(reference - itk.array_from_image(warped_image_A)[40])), 1e-6 + ) + def test_itk_consistency(self): + import torch + + img = torch.tensor([[1, 2, 3], [4, 5.0, 6], [7, 8, 13]])[None, None] + + warper = icon_registration.RegistrationModule() + + warper.assign_identity_map([1, 1, 3, 3]) + + deformation = torch.zeros((1, 2, 3, 3)) + + deformation[0, 0, 0, 0] = 0.5 + + deformation[0, 1, 2, 2] = -0.5 + + img_func = warper.as_function(img) + + warped_image_torch = img_func(warper.identity_map + deformation) + + itk_img = itk.image_from_array(img[0, 0]) + + sp = itk.Vector[itk.D, 2]((4.2, 6.9)) + + itk_img.SetSpacing(sp) + + phi = icon_registration.itk_wrapper.create_itk_transform( + deformation + warper.identity_map, warper.identity_map, itk_img, itk_img + ) + + interpolator = itk.LinearInterpolateImageFunction.New(itk_img) + + warped_image_A = itk.resample_image_filter( + itk_img, + transform=phi, + interpolator=interpolator, + use_reference_image=True, + reference_image=itk_img, + ) + + self.assertTrue( + np.all(np.array(warped_image_A) == np.array(warped_image_torch)) + ) + diff --git a/ICON/test/test_knee_registration.py b/ICON/test/test_knee_registration.py new file mode 100644 index 00000000..609ceb42 --- /dev/null +++ b/ICON/test/test_knee_registration.py @@ -0,0 +1,183 @@ +import unittest + + +class TestKneeRegistration(unittest.TestCase): + def test_knee_registration(self): + print("OAI ICON") + + import icon_registration.pretrained_models + from icon_registration.mermaidlite import compute_warped_image_multiNC + from icon_registration.losses import flips + + import torch + import numpy as np + + import subprocess + + print("Downloading test data)") + import icon_registration.test_utils + + icon_registration.test_utils.download_test_data() + t_ds = torch.load( + icon_registration.test_utils.TEST_DATA_DIR / "icon_example_data" + ) + batched_ds = list(zip(*[t_ds[i::2] for i in range(2)])) + net = icon_registration.pretrained_models.OAI_knees_registration_model( + pretrained=True + ) + # Run on the four downloaded image pairs + + with torch.no_grad(): + dices = [] + folds_list = [] + + for x in batched_ds[:]: + # Seperate the image data used for registration from the segmentation used for evaluation, + # and shape it for passing to the network + x = list(zip(*x)) + x = [torch.cat(r, 0).cuda().float() for r in x] + fixed_image, fixed_cartilage = x[0], x[2] + moving_image, moving_cartilage = x[1], x[3] + + # Run the registration. + # Our network expects batches of two pairs, + # moving_image.size = torch.Size([2, 1, 80, 192, 192]) + # fixed_image.size = torch.Size([2, 1, 80, 192, 192]) + # intensity normalized to have min 0 and max 1. + + net(moving_image, fixed_image) + + # Once registration is run, net.phi_AB and net.phi_BA are functions that map + # tensors of coordinates from image B to A and A to B respectively. 
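            # (Illustrative aside, not part of the test: phi_AB and phi_BA act on coordinate
            # tensors shaped like net.identity_map, so composing them and comparing against
            # the identity map gives the inverse-consistency error by hand, e.g.
            #     icon_error = net.phi_AB(net.phi_BA(net.identity_map)) - net.identity_map
            # The evaluation below only needs the forward map on the identity grid.)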
+ + # Evaluate the registration + # First, evaluate phi_AB on a tensor of coordinates to get an explicit map. + phi_AB_vectorfield = net.phi_AB(net.identity_map) + fat_phi = torch.nn.Upsample( + size=moving_cartilage.size()[2:], + mode="trilinear", + align_corners=False, + )(phi_AB_vectorfield[:, :3]) + sz = np.array(fat_phi.size()) + spacing = 1.0 / (sz[2::] - 1) + + # Warp the cartilage of one image to match the other using the explicit map. + warped_moving_cartilage = compute_warped_image_multiNC( + moving_cartilage.float(), fat_phi, spacing, 1 + ) + + # Binarize the segmentations + wmb = warped_moving_cartilage > 0.5 + fb = fixed_cartilage > 0.5 + + # Compute the dice metric + intersection = wmb * fb + dice = ( + 2 + * torch.sum(intersection, [1, 2, 3, 4]).float() + / (torch.sum(wmb, [1, 2, 3, 4]) + torch.sum(fb, [1, 2, 3, 4])) + ) + print("Batch DICE:", dice) + dices.append(dice) + + # Compute the folds metric + f = [flips(phi[None]).item() for phi in phi_AB_vectorfield] + print("Batch folds per image:", f) + folds_list.append(f) + + mean_dice = torch.mean(torch.cat(dices).cpu()) + print("Mean DICE SCORE:", mean_dice) + self.assertTrue(mean_dice.item() > 0.68) + mean_folds = np.mean(folds_list) + print("Mean folds per image:", mean_folds) + self.assertTrue(mean_folds < 300) + + def test_knee_registration_gradICON(self): + print("OAI gradICON") + + import icon_registration.pretrained_models + from icon_registration.mermaidlite import compute_warped_image_multiNC + from icon_registration.losses import flips + + import torch + import numpy as np + + import subprocess + + print("Downloading test data)") + import icon_registration.test_utils + + icon_registration.test_utils.download_test_data() + t_ds = torch.load( + icon_registration.test_utils.TEST_DATA_DIR / "icon_example_data" + ) + batched_ds = list(zip(*[t_ds[i::2] for i in range(2)])) + net = icon_registration.pretrained_models.OAI_knees_gradICON_model( + pretrained=True + ) + # Run on the four downloaded image pairs + + with torch.no_grad(): + dices = [] + folds_list = [] + + for x in batched_ds[:]: + # Seperate the image data used for registration from the segmentation used for evaluation, + # and shape it for passing to the network + x = list(zip(*x)) + x = [torch.cat(r, 0).cuda().float() for r in x] + fixed_image, fixed_cartilage = x[0], x[2] + moving_image, moving_cartilage = x[1], x[3] + + # Run the registration. + # Our network expects batches of two pairs, + # moving_image.size = torch.Size([2, 1, 80, 192, 192]) + # fixed_image.size = torch.Size([2, 1, 80, 192, 192]) + # intensity normalized to have min 0 and max 1. + + net(moving_image, fixed_image) + + # Once registration is run, net.phi_AB and net.phi_BA are functions that map + # tensors of coordinates from image B to A and A to B respectively. + + # Evaluate the registration + # First, evaluate phi_AB on a tensor of coordinates to get an explicit map. + phi_AB_vectorfield = net.phi_AB(net.identity_map) + fat_phi = torch.nn.Upsample( + size=moving_cartilage.size()[2:], + mode="trilinear", + align_corners=False, + )(phi_AB_vectorfield[:, :3]) + sz = np.array(fat_phi.size()) + spacing = 1.0 / (sz[2::] - 1) + + # Warp the cartilage of one image to match the other using the explicit map. 
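                # (compute_warped_image_multiNC samples the moving segmentation at the
                # coordinates in fat_phi; the final argument 1 is the spline order, i.e.
                # trilinear interpolation. fat_phi is the spatial part of the map, upsampled
                # above to the full segmentation resolution.)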
+ warped_moving_cartilage = compute_warped_image_multiNC( + moving_cartilage.float(), fat_phi, spacing, 1 + ) + + # Binarize the segmentations + wmb = warped_moving_cartilage > 0.5 + fb = fixed_cartilage > 0.5 + + # Compute the dice metric + intersection = wmb * fb + dice = ( + 2 + * torch.sum(intersection, [1, 2, 3, 4]).float() + / (torch.sum(wmb, [1, 2, 3, 4]) + torch.sum(fb, [1, 2, 3, 4])) + ) + print("Batch DICE:", dice) + dices.append(dice) + + # Compute the folds metric + f = [flips(phi[None]).item() for phi in phi_AB_vectorfield] + print("Batch folds per image:", f) + folds_list.append(f) + + mean_dice = torch.mean(torch.cat(dices).cpu()) + print("Mean DICE SCORE:", mean_dice) + self.assertTrue(mean_dice.item() > 0.68) + mean_folds = np.mean(folds_list) + print("Mean folds per image:", mean_folds) + self.assertTrue(mean_folds < 300) diff --git a/ICON/test/test_losses.py b/ICON/test/test_losses.py new file mode 100644 index 00000000..2fed0f38 --- /dev/null +++ b/ICON/test/test_losses.py @@ -0,0 +1,127 @@ +import torch +from icon_registration.losses import AdaptiveNCC, normalize +import math +import numbers +import torch +from torch import nn + +import unittest + +class TestLosses(unittest.TestCase): + def test_adaptive_ncc(self): + a = torch.rand(1,1,64,64,64) + b = torch.rand(1,1,64,64,64) + + sim = AdaptiveNCC() + l = sim(a, b) + sim_origin = adaptive_ncc() + l_origin = sim_origin(a, b) + + self.assertAlmostEqual(l.item(), l_origin.item(), places=5) + +class GaussianSmoothing(nn.Module): + """ + Apply gaussian smoothing on a + 1d, 2d or 3d tensor. Filtering is performed seperately for each channel + in the input using a depthwise convolution. + Arguments: + channels (int, sequence): Number of channels of the input tensors. Output will + have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. + sigma (float, sequence): Standard deviation of the gaussian kernel. + dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ + + def __init__(self, channels, kernel_size, sigma, dim=2): + super(GaussianSmoothing, self).__init__() + if isinstance(kernel_size, numbers.Number): + kernel_size = [kernel_size] * dim + if isinstance(sigma, numbers.Number): + sigma = [sigma] * dim + + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [torch.arange(size, dtype=torch.float32) for size in kernel_size] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= ( + 1 + / (std * math.sqrt(2 * math.pi)) + * torch.exp(-(((mgrid - mean) / std) ** 2) / 2) + ) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer("weight", kernel) + self.groups = channels + + if dim == 1: + self.conv = nn.functional.conv1d + elif dim == 2: + self.conv = nn.functional.conv2d + elif dim == 3: + self.conv = nn.functional.conv3d + else: + raise RuntimeError( + "Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim) + ) + + def forward(self, input): + """ + Apply gaussian filter to input. + Arguments: + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. 
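        Example (illustrative only; matches the default smoother used by adaptive_ncc below):
            smoother = GaussianSmoothing(channels=1, kernel_size=5, sigma=2, dim=3)
            blurred = smoother(torch.rand(1, 1, 64, 64, 64))  # same spatial size, since padding = kernel_size // 2 for the odd kernel used here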
+ """ + return self.conv( + input, + weight=self.weight, + groups=self.groups, + padding=int(self.weight.shape[2] / 2), + ) + +def adaptive_ncc( + level=4, threshold=0.1, gamma=1.5, smoother=GaussianSmoothing(1, 5, 2, 3)): + def _nccBeforeMean(image_A, image_B): + A = normalize(image_A[:, :1]) + B = normalize(image_B) + res = torch.mean(A * B, dim=(1, 2, 3, 4)) + return 1 - res + + def _sim(x, y): + sims = [_nccBeforeMean(x, y)] + for i in range(level): + if i == 0: + sims.append(_nccBeforeMean(smoother(x), smoother(y))) + else: + sims.append( + _nccBeforeMean( + smoother(nn.functional.avg_pool3d(x, 2**i)), + smoother(nn.functional.avg_pool3d(y, 2**i)), + ) + ) + + sim_loss = sims[0] + 0 + lamb_ = 1.0 + for i in range(1, len(sims)): + lamb = torch.clamp( + sims[i].detach() / (threshold / (gamma ** (len(sims) - i))), 0, 1 + ) + sim_loss = lamb * sims[i] + (1 - lamb) * sim_loss + lamb_ *= 1 - lamb + + return torch.mean(sim_loss) + + return _sim + + \ No newline at end of file diff --git a/ICON/test/test_lung_itk.py b/ICON/test/test_lung_itk.py new file mode 100644 index 00000000..f87e0e1e --- /dev/null +++ b/ICON/test/test_lung_itk.py @@ -0,0 +1,126 @@ +import itk +import numpy as np +import unittest +import matplotlib.pyplot as plt +import numpy as np + +import icon_registration.test_utils +import icon_registration.pretrained_models +import icon_registration.itk_wrapper + + +class TestItkRegistration(unittest.TestCase): + def test_itk_registration(self): + + model = icon_registration.pretrained_models.LungCT_registration_model( + pretrained=True + ) + + icon_registration.test_utils.download_test_data() + + image_exp = itk.imread( + str( + icon_registration.test_utils.TEST_DATA_DIR + / "lung_test_data/copd1_highres_EXP_STD_COPD_img.nii.gz" + ) + ) + image_insp = itk.imread( + str( + icon_registration.test_utils.TEST_DATA_DIR + / "lung_test_data/copd1_highres_INSP_STD_COPD_img.nii.gz" + ) + ) + image_exp_seg = itk.imread( + str( + icon_registration.test_utils.TEST_DATA_DIR + / "lung_test_data/copd1_highres_EXP_STD_COPD_label.nii.gz" + ) + ) + image_insp_seg = itk.imread( + str( + icon_registration.test_utils.TEST_DATA_DIR + / "lung_test_data/copd1_highres_INSP_STD_COPD_label.nii.gz" + ) + ) + + image_insp_preprocessed = ( + icon_registration.pretrained_models.lung_network_preprocess( + image_insp, image_insp_seg + ) + ) + image_exp_preprocessed = ( + icon_registration.pretrained_models.lung_network_preprocess( + image_exp, image_exp_seg + ) + ) + + phi_AB, phi_BA = icon_registration.itk_wrapper.register_pair( + model, image_insp_preprocessed, image_exp_preprocessed, finetune_steps=None + ) + + assert isinstance(phi_AB, itk.CompositeTransform) + interpolator = itk.LinearInterpolateImageFunction.New(image_insp_preprocessed) + + warped_image_insp_preprocessed = itk.resample_image_filter( + image_insp_preprocessed, + transform=phi_AB, + interpolator=interpolator, + size=itk.size(image_exp_preprocessed), + output_spacing=itk.spacing(image_exp_preprocessed), + output_direction=image_exp_preprocessed.GetDirection(), + output_origin=image_exp_preprocessed.GetOrigin(), + ) + + # log some images to show the registration + import os + os.environ["FOOTSTEPS_NAME"] = "test" + import footsteps + + plt.imshow( + np.array( + itk.checker_board_image_filter( + warped_image_insp_preprocessed, image_exp_preprocessed + ) + )[140] + ) + plt.colorbar() + plt.savefig(footsteps.output_dir + "grid_lung.png") + plt.clf() + plt.imshow(np.array(warped_image_insp_preprocessed)[140]) + plt.colorbar() + 
plt.savefig(footsteps.output_dir + "warped_lung.png") + plt.clf() + plt.imshow( + np.array(warped_image_insp_preprocessed)[140] + - np.array(image_exp_preprocessed)[140] + ) + plt.colorbar() + plt.savefig(footsteps.output_dir + "difference_lung.png") + plt.clf() + + insp_points = icon_registration.test_utils.read_copd_pointset( + "test_files/lung_test_data/copd1_300_iBH_xyz_r1.txt" + ) + exp_points = icon_registration.test_utils.read_copd_pointset( + "test_files/lung_test_data/copd1_300_eBH_xyz_r1.txt" + ) + dists = [] + for i in range(len(insp_points)): + px, py = ( + exp_points[i], + np.array(phi_BA.TransformPoint(tuple(insp_points[i]))), + ) + dists.append(np.sqrt(np.sum((px - py) ** 2))) + print(np.mean(dists)) + + self.assertLess(np.mean(dists), 1.5) + dists = [] + for i in range(len(insp_points)): + px, py = ( + insp_points[i], + np.array(phi_AB.TransformPoint(tuple(exp_points[i]))), + ) + dists.append(np.sqrt(np.sum((px - py) ** 2))) + print(np.mean(dists)) + self.assertLess(np.mean(dists), 2.3) + diff --git a/ICON/test/test_lung_registration.py b/ICON/test/test_lung_registration.py new file mode 100644 index 00000000..f998264a --- /dev/null +++ b/ICON/test/test_lung_registration.py @@ -0,0 +1,308 @@ +import os +import unittest + +import itk +import numpy as np +import torch +import torch.nn.functional as F +import tqdm + +import icon_registration +import icon_registration.pretrained_models + +COPD_spacing = { + "copd1": [0.625, 0.625, 2.5], + "copd2": [0.645, 0.645, 2.5], + "copd3": [0.652, 0.652, 2.5], + "copd4": [0.590, 0.590, 2.5], + "copd5": [0.647, 0.647, 2.5], + "copd6": [0.633, 0.633, 2.5], + "copd7": [0.625, 0.625, 2.5], + "copd8": [0.586, 0.586, 2.5], + "copd9": [0.664, 0.664, 2.5], + "copd10": [0.742, 0.742, 2.5], +} + + +def readPoint(f_path): + """ + :param f_path: the path to the file containing the position of points. + Points are deliminated by '\n' and X,Y,Z of each point are deliminated by '\t'. + :return: numpy list of positions. + """ + with open(f_path) as fp: + content = fp.read().split("\n") + + # Read number of points from second + count = len(content) - 1 + + # Read the points + points = np.ndarray([count, 3], dtype=np.float64) + + for i in range(count): + if content[i] == "": + break + temp = content[i].split("\t") + points[i, 0] = float(temp[0]) + points[i, 1] = float(temp[1]) + points[i, 2] = float(temp[2]) + + return points + + +def calc_warped_points(source_list_t, phi_t, dim, spacing, phi_spacing): + """ + :param source_list_t: source image. + :param phi_t: the inversed displacement. Domain in source coordinate. + :param dim: voxel dimenstions. + :param spacing: image spacing. + :return: a N*3 tensor containg warped positions in the physical coordinate. + """ + warped_list_t = F.grid_sample(phi_t, source_list_t, align_corners=True) + + warped_list_t = torch.flip(warped_list_t.permute(0, 2, 3, 4, 1), [4])[0, 0, 0] + warped_list_t = torch.mul( + torch.mul(warped_list_t, torch.from_numpy(dim - 1.0)), + torch.from_numpy(phi_spacing), + ) + + return warped_list_t + + +def eval_with_data(source_list, target_list, phi, dim, spacing, origin, phi_spacing): + """ + :param source_list: a numpy list of markers' position in source image. + :param target_list: a numpy list of markers' position in target image. + :param phi: displacement map in numpy format. + :param dim: voxel dimenstions. + :param spacing: image spacing. + :param return: res, [dist_x, dist_y, dist_z] res is the distance between + the warped points and target points in MM. 
[dist_x, dist_y, dist_z] are + distances in MM along x,y,z axis perspectively. + """ + origin_list = np.repeat( + [ + origin, + ], + target_list.shape[0], + axis=0, + ) + + # Translate landmark from landmark coord to phi coordinate + target_list_t = ( + torch.from_numpy((target_list - 1.0) * spacing) - origin_list * phi_spacing + ) + source_list_t = ( + torch.from_numpy((source_list - 1.0) * spacing) - origin_list * phi_spacing + ) + + # Pay attention to the definition of align_corners in grid_sampling. + # Translate landmarks to voxel index in image space [-1, 1] + source_list_norm = source_list_t / phi_spacing / (dim - 1.0) * 2.0 - 1.0 + source_list_norm = source_list_norm.unsqueeze(0).unsqueeze(0).unsqueeze(0) + + phi_t = torch.from_numpy(phi).double() + + warped_list_t = calc_warped_points( + source_list_norm, phi_t, dim, spacing, phi_spacing + ) + + pdist = torch.nn.PairwiseDistance(p=2) + dist = pdist(target_list_t, warped_list_t) + idx = torch.argsort(dist).numpy() + dist_x = torch.mean(torch.abs(target_list_t[:, 0] - warped_list_t[:, 0])).item() + dist_y = torch.mean(torch.abs(target_list_t[:, 1] - warped_list_t[:, 1])).item() + dist_z = torch.mean(torch.abs(target_list_t[:, 2] - warped_list_t[:, 2])).item() + res = torch.mean(dist).item() + + return res, [dist_x, dist_y, dist_z] + + +def compute_dice(x, y): + eps = 1e-11 + y_loc = set(np.where(y.flatten() == 1)[0]) + x_loc = set(np.where(x.flatten() == 1)[0]) + # iou + intersection = set.intersection(x_loc, y_loc) + # recall + len_intersection = len(intersection) + tp = float(len_intersection) + fn = float(len(y_loc) - len_intersection) + fp = float(len(x_loc) - len_intersection) + + if len(y_loc) != 0 or len(x_loc) != 0: + return 2 * tp / (2 * tp + fn + fp + eps) + + return 0.0 + + +def eval_copd_highres(seg_folder, case_id, phi, phi_inv, phi_spacing, origin, dim): + """ + :param dataset_path: the path to the dataset folder. The folder structure assumption: + dataset_path/landmarks stores landmarks files; dataset_path/segments stores segmentation maps + + :param case_id: a list of case id + :param phi: phi defined in expiration image domain. Numpy array. Bx3xDxWxH. Orientation: SI, AP, RL orientation + :param phi_inv: phi_inv defined in inspiration image domain. Numpy array. Bx3xDxWxH + :param phi_spacing: physical spacing of the domain of phi. Numpy array. 
Bx3xDxWxH + """ + + def _eval(phi, case_id, phi_spacing, origin, dim, inv=False): + result = {} + if inv: + source_file = os.path.join(seg_folder, f"{case_id}_300_iBH_xyz_r1.txt") + target_file = os.path.join(seg_folder, f"{case_id}_300_eBH_xyz_r1.txt") + else: + source_file = os.path.join(seg_folder, f"{case_id}_300_eBH_xyz_r1.txt") + target_file = os.path.join(seg_folder, f"{case_id}_300_iBH_xyz_r1.txt") + + spacing = COPD_spacing[case_id] + + source_list = readPoint(source_file) + target_list = readPoint(target_file) + + mTRE, mTRE_seperate = eval_with_data( + source_list, target_list, phi, dim, spacing, origin, phi_spacing + ) + + result["mTRE"] = mTRE + result["mTRE_X"] = mTRE_seperate[0] + result["mTRE_Y"] = mTRE_seperate[1] + result["mTRE_Z"] = mTRE_seperate[2] + + return result + + results = {} + for i, case in enumerate(case_id): + results[case] = _eval( + phi[i : i + 1], case, phi_spacing[i], origin[i], dim[i], inv=False + ) + results[f"{case}_inv"] = _eval( + phi_inv[i : i + 1], case, phi_spacing[i], origin[i], dim[i], inv=True + ) + return results + + +class TestLungRegistration(unittest.TestCase): + def test_lung_registration(self): + print("lung gradICON") + + net = icon_registration.pretrained_models.LungCT_registration_model( + pretrained=True + ) + icon_registration.test_utils.download_test_data() + + root = str(icon_registration.test_utils.TEST_DATA_DIR / "lung_test_data") + cases = [f"copd{i}_highres" for i in range(1, 2)] + hu_clip_range = [-1000, 0] + + # Load data + def process(iA, isSeg=False): + iA = iA[None, None, :, :, :] + if isSeg: + iA = iA.float() + iA = torch.nn.functional.max_pool3d(iA, 2) + iA[iA > 0] = 1 + else: + iA = torch.clip(iA, hu_clip_range[0], hu_clip_range[1]) + iA = iA / 1000 + + iA = torch.nn.functional.avg_pool3d(iA, 2) + return iA + + def read_itk(path): + img = itk.imread(path) + return torch.tensor(np.asarray(img)), np.flipud(list(img.GetOrigin())) + + dirlab = [] + dirlab_seg = [] + dirlab_origin = [] + for name in tqdm.tqdm(list(iter(cases))[:]): + image_insp, _ = read_itk(f"{root}/{name}_INSP_STD_COPD_img.nii.gz") + image_exp, _ = read_itk(f"{root}/{name}_EXP_STD_COPD_img.nii.gz") + seg_insp, _ = read_itk(f"{root}/{name}_INSP_STD_COPD_label.nii.gz") + seg_exp, origin = read_itk(f"{root}/{name}_EXP_STD_COPD_label.nii.gz") + + # dirlab.append((process(image_insp), process(image_exp))) + dirlab.append( + ( + ((process(image_insp) + 1) * process(seg_insp, True)), + ((process(image_exp) + 1) * process(seg_exp, True)), + ) + ) + dirlab_seg.append((process(seg_insp, True), process(seg_exp, True))) + dirlab_origin.append(origin) + + def make_batch(dataset, dataset_seg): + image_A = torch.cat([p[0] for p in dataset]).cuda() + image_B = torch.cat([p[1] for p in dataset]).cuda() + + image_A_seg = torch.cat([p[0] for p in dataset_seg]).cuda() + image_B_seg = torch.cat([p[1] for p in dataset_seg]).cuda() + return image_A, image_B, image_A_seg, image_B_seg + + net.cuda() + + train_A, train_B, train_A_seg, train_B_seg = make_batch(dirlab, dirlab_seg) + + phis = [] + phis_inv = [] + warped_A = [] + for i in range(train_A.shape[0]): + with torch.no_grad(): + print(net(train_A[i : i + 1], train_B[i : i + 1])) + + phis.append((net.phi_AB_vectorfield.detach() * 2.0 - 1.0)) + phis_inv.append((net.phi_BA_vectorfield.detach() * 2.0 - 1.0)) + warped_A.append(net.warped_image_A[:, 0:1].detach()) + + warped_A = torch.cat(warped_A) + + phis_np = (torch.cat(phis).cpu().numpy() + 1.0) / 2.0 + phis_inv_np = (torch.cat(phis_inv).cpu().numpy() + 1.0) / 2.0 + res = 
eval_copd_highres( + root, + [c[:-8] for c in cases], + phis_np, + phis_inv_np, + np.repeat(np.array([[1.0, 1.0, 1.0]]), len(cases), axis=0), + np.array([np.flipud(i) for i in dirlab_origin]), + np.repeat(np.array([[350, 350, 350]]), len(cases), axis=0), + ) + + results = [] + results_inv = [] + for k, v in res.items(): + result = [] + for m, n in v.items(): + result.append(n) + if "_inv" in k: + results_inv.append(result) + else: + results.append(result) + results = np.array(results) + + results_str = f"mTRE: {results[:,0].mean()}, mTRE_X: {results[:,1].mean()}, mTRE_Y: {results[:,2].mean()}, mTRE_Z: {results[:,3].mean()}" + print(results_str) + results_inv = np.array(results_inv) + results_inv_str = f"mTRE: {results_inv[:,0].mean()}, mTRE_X: {results_inv[:,1].mean()}, mTRE_Y: {results_inv[:,2].mean()}, mTRE_Z: {results_inv[:,3].mean()}" + print(results_inv_str) + + # Compute Dice + + warped_train_A_seg = F.grid_sample( + train_A_seg.float(), + torch.cat(phis).flip([1]).permute([0, 2, 3, 4, 1]), + padding_mode="zeros", + mode="nearest", + align_corners=True, + ) + dices = [] + for i in range(warped_train_A_seg.shape[0]): + dices.append( + compute_dice( + train_B_seg[i].cpu().numpy(), + warped_train_A_seg[i].detach().cpu().numpy(), + ) + ) + print(np.array(dices).mean()) From 8b8f7baa198b5a12fbedc32ebee48bc3c497058c Mon Sep 17 00:00:00 2001 From: Hastings Greer Date: Tue, 19 Dec 2023 15:21:54 -0500 Subject: [PATCH 04/15] refactor losses --- ICON/icon_registration/losses.py | 436 ++++++--------------- ICON/icon_registration/network_wrappers.py | 1 + ICON/test/test_losses.py | 127 ------ 3 files changed, 127 insertions(+), 437 deletions(-) delete mode 100644 ICON/test/test_losses.py diff --git a/ICON/icon_registration/losses.py b/ICON/icon_registration/losses.py index 0cc0a2bb..7587524c 100644 --- a/ICON/icon_registration/losses.py +++ b/ICON/icon_registration/losses.py @@ -6,44 +6,15 @@ from icon_registration import config, network_wrappers -from .mermaidlite import compute_warped_image_multiNC +import registration_module - -def to_floats(stats): - out = [] - for v in stats: - if isinstance(v, torch.Tensor): - v = torch.mean(v).cpu().item() - out.append(v) - return ICONLoss(*out) - -class ICON(network_wrappers.RegistrationModule): +class Loss(registration_module.RegistrationModule): def __init__(self, network, similarity, lmbda): - super().__init__() - self.regis_net = network self.lmbda = lmbda self.similarity = similarity - - def __call__(self, image_A, image_B) -> ICONLoss: - return super().__call__(image_A, image_B) - - def forward(self, image_A, image_B): - - assert self.identity_map.shape[2:] == image_A.shape[2:] - assert self.identity_map.shape[2:] == image_B.shape[2:] - - # Tag used elsewhere for optimization. 
- # Must be set at beginning of forward b/c not preserved by .cuda() etc - self.identity_map.isIdentity = True - - self.phi_AB = self.regis_net(image_A, image_B) - self.phi_BA = self.regis_net(image_B, image_A) - - self.phi_AB_vectorfield = self.phi_AB(self.identity_map) - self.phi_BA_vectorfield = self.phi_BA(self.identity_map) - + def compute_similarity(self, image_A, image_B, phi_AB_vectorfield): if getattr(self.similarity, "isInterpolated", False): # tag images during warping so that the similarity measure # can use information about whether a sample is interpolated @@ -58,32 +29,51 @@ def forward(self, image_A, image_B): else: inbounds_tag = None - self.warped_image_A = compute_warped_image_multiNC( - torch.cat([image_A, inbounds_tag], axis=1) if inbounds_tag is not None else image_A, + warped_image_A = self.as_function( + torch.cat([image_A, inbounds_tag], axis=1) if inbounds_tag is not None else image_A)( self.phi_AB_vectorfield, - self.spacing, - 1, ) - self.warped_image_B = compute_warped_image_multiNC( - torch.cat([image_B, inbounds_tag], axis=1) if inbounds_tag is not None else image_B, - self.phi_BA_vectorfield, - self.spacing, - 1, + similarity_loss = self.similarity( + warped_image_A, image_B ) + return {"similarity_loss":similarity_loss, "warped_image_A":warped_image_A} - similarity_loss = self.similarity( - self.warped_image_A, image_B - ) + self.similarity(self.warped_image_B, image_A) +class TwoWayRegularizer(Loss): + def forward(self, image_A, image_B): + assert self.identity_map.shape[2:] == image_A.shape[2:] + assert self.identity_map.shape[2:] == image_B.shape[2:] + + # Tag used elsewhere for optimization. + # Must be set at beginning of forward b/c not preserved by .cuda() etc + self.identity_map.isIdentity = True + + phi_AB = self.regis_net(image_A, image_B)["phi_AB"] + phi_BA = self.regis_net(image_B, image_A)["phi_AB"] + phi_AB_vectorfield = phi_AB(self.identity_map) + phi_BA_vectorfield = phi_BA(self.identity_map) + + similarity_AB = self.compute_similarity( image_A, image_B, phi_AB_vectorfield) + similarity_BA = self.compute_similarity( image_B, image_A, phi_BA_vectorfield) + + similarity_loss = similarity_AB["similarity_loss"] + similarity_BA["similarity_loss"] + regularization_loss = compute_regularizer(self, phi_AB, phi_BA) + + all_loss = self.lmbda * gradient_inverse_consistency_loss + similarity_loss + + negative_jacobian_voxels = flips(phi_BA_vectorfield) + + return {"all_loss": all_loss, "regularization_loss": inverse_consistency_loss, "similarity_loss": similarity_loss, "phi_AB": phi_AB, "phi_BA": phi_BA, "warped_image_A": similarity_AB["warped_image_A"], "warped_image_B": similiarity_BA["warped_image_A"], "negative_jacobian_voxels":negative_jacobian_voxels} + + +class ICON(TwoWayRegularizer): + + def compute_regularizer(self, phi_AB, phi_BA): Iepsilon = ( self.identity_map + torch.randn(*self.identity_map.shape).to(image_A.device) - * 1 - / self.identity_map.shape[-1] ) - # inverse consistency one way - approximate_Iepsilon1 = self.phi_AB(self.phi_BA(Iepsilon)) approximate_Iepsilon2 = self.phi_BA(self.phi_AB(Iepsilon)) @@ -92,29 +82,22 @@ def forward(self, image_A, image_B): (Iepsilon - approximate_Iepsilon1) ** 2 ) + torch.mean((Iepsilon - approximate_Iepsilon2) ** 2) - transform_magnitude = torch.mean( - (self.identity_map - self.phi_AB_vectorfield) ** 2 - ) - - all_loss = self.lmbda * inverse_consistency_loss + similarity_loss + inverse_consistency_loss /= self.input_shape[2] ** 2 + + return inverse_consistency_loss - return ICONLoss( - all_loss, - 
inverse_consistency_loss, - similarity_loss, - transform_magnitude, - flips(self.phi_BA_vectorfield), - ) -class GradICON(network_wrappers.RegistrationModule): - def compute_gradient_icon_loss(self, phi_AB, phi_BA): +class GradICON(TwoWayRegularizer): + def compute_regularizer(self, phi_AB, phi_BA): Iepsilon = ( self.identity_map + torch.randn(*self.identity_map.shape).to(self.identity_map.device) - * 1 - / self.identity_map.shape[-1] ) + if len(self.input_shape) - 2 == 3: + Iepsilon =Iepsilon [:, :, ::2, ::2, ::2] + elif len(self.input_shape) - 2 == 2: + Iepsilon = Iepsilon[:, :, ::2, ::2] # compute squared Frobenius of Jacobian of icon error @@ -154,84 +137,12 @@ def compute_gradient_icon_loss(self, phi_AB, phi_BA): ) / delta direction_losses.append(torch.mean(grad_d_icon_error**2)) - inverse_consistency_loss = sum(direction_losses) - - return inverse_consistency_loss - - def compute_similarity_measure(self, phi_AB, phi_BA, image_A, image_B): - self.phi_AB_vectorfield = phi_AB(self.identity_map) - self.phi_BA_vectorfield = phi_BA(self.identity_map) - - if getattr(self.similarity, "isInterpolated", False): - # tag images during warping so that the similarity measure - # can use information about whether a sample is interpolated - # or extrapolated - inbounds_tag = torch.zeros([image_A.shape[0]] + [1] + list(image_A.shape[2:]), device=image_A.device) - if len(self.input_shape) - 2 == 3: - inbounds_tag[:, :, 1:-1, 1:-1, 1:-1] = 1.0 - elif len(self.input_shape) - 2 == 2: - inbounds_tag[:, :, 1:-1, 1:-1] = 1.0 - else: - inbounds_tag[:, :, 1:-1] = 1.0 - else: - inbounds_tag = None - - self.warped_image_A = self.as_function( - torch.cat([image_A, inbounds_tag], axis=1) if inbounds_tag is not None else image_A - )(self.phi_AB_vectorfield) - self.warped_image_B = self.as_function( - torch.cat([image_B, inbounds_tag], axis=1) if inbounds_tag is not None else image_B - )(self.phi_BA_vectorfield) - similarity_loss = self.similarity( - self.warped_image_A, image_B - ) + self.similarity(self.warped_image_B, image_A) - return similarity_loss - - def forward(self, image_A, image_B) -> ICONLoss: - - assert self.identity_map.shape[2:] == image_A.shape[2:] - assert self.identity_map.shape[2:] == image_B.shape[2:] - - # Tag used elsewhere for optimization. 
- # Must be set at beginning of forward b/c not preserved by .cuda() etc - self.identity_map.isIdentity = True - - self.phi_AB = self.regis_net(image_A, image_B) - self.phi_BA = self.regis_net(image_B, image_A) - - similarity_loss = self.compute_similarity_measure( - self.phi_AB, self.phi_BA, image_A, image_B - ) - - inverse_consistency_loss = self.compute_gradient_icon_loss( - self.phi_AB, self.phi_BA - ) - - all_loss = self.lmbda * inverse_consistency_loss + similarity_loss - - transform_magnitude = torch.mean( - (self.identity_map - self.phi_AB_vectorfield) ** 2 - ) - return ICONLoss( - all_loss, - inverse_consistency_loss, - similarity_loss, - transform_magnitude, - flips(self.phi_BA_vectorfield), - ) - - -class GradientICONSparse(network_wrappers.RegistrationModule): - def __init__(self, network, similarity, lmbda): + gradient_inverse_consistency_loss = sum(direction_losses) - super().__init__() - - self.regis_net = network - self.lmbda = lmbda - self.similarity = similarity + return gradient_inverse_consistency_loss +class OneWayRegularizer(Loss): def forward(self, image_A, image_B): - assert self.identity_map.shape[2:] == image_A.shape[2:] assert self.identity_map.shape[2:] == image_B.shape[2:] @@ -239,122 +150,29 @@ def forward(self, image_A, image_B): # Must be set at beginning of forward b/c not preserved by .cuda() etc self.identity_map.isIdentity = True - self.phi_AB = self.regis_net(image_A, image_B) - self.phi_BA = self.regis_net(image_B, image_A) - - self.phi_AB_vectorfield = self.phi_AB(self.identity_map) - self.phi_BA_vectorfield = self.phi_BA(self.identity_map) - - # tag images during warping so that the similarity measure - # can use information about whether a sample is interpolated - # or extrapolated - - if getattr(self.similarity, "isInterpolated", False): - # tag images during warping so that the similarity measure - # can use information about whether a sample is interpolated - # or extrapolated - inbounds_tag = torch.zeros([image_A.shape[0]] + [1] + list(image_A.shape[2:]), device=image_A.device) - if len(self.input_shape) - 2 == 3: - inbounds_tag[:, :, 1:-1, 1:-1, 1:-1] = 1.0 - elif len(self.input_shape) - 2 == 2: - inbounds_tag[:, :, 1:-1, 1:-1] = 1.0 - else: - inbounds_tag[:, :, 1:-1] = 1.0 - else: - inbounds_tag = None - - self.warped_image_A = compute_warped_image_multiNC( - torch.cat([image_A, inbounds_tag], axis=1) if inbounds_tag is not None else image_A, - self.phi_AB_vectorfield, - self.spacing, - 1, - ) - self.warped_image_B = compute_warped_image_multiNC( - torch.cat([image_B, inbounds_tag], axis=1) if inbounds_tag is not None else image_B, - self.phi_BA_vectorfield, - self.spacing, - 1, - ) - - similarity_loss = self.similarity( - self.warped_image_A, image_B - ) + self.similarity(self.warped_image_B, image_A) - - if len(self.input_shape) - 2 == 3: - Iepsilon = ( - self.identity_map - + 2 * torch.randn(*self.identity_map.shape).to(config.device) - / self.identity_map.shape[-1] - )[:, :, ::2, ::2, ::2] - elif len(self.input_shape) - 2 == 2: - Iepsilon = ( - self.identity_map - + 2 * torch.randn(*self.identity_map.shape).to(config.device) - / self.identity_map.shape[-1] - )[:, :, ::2, ::2] - - # compute squared Frobenius of Jacobian of icon error + phi_AB = self.regis_net(image_A, image_B)["phi_AB"] - direction_losses = [] + phi_AB_vectorfield = phi_AB(self.identity_map) - approximate_Iepsilon = self.phi_AB(self.phi_BA(Iepsilon)) + similarity_AB = self.compute_similarity( image_A, image_B, phi_AB_vectorfield) - inverse_consistency_error = Iepsilon - 
approximate_Iepsilon + similarity_loss = 2 * similarity_AB["similarity_loss"] + regularization_loss = compute_regularizer(self, phi_AB_vectorfield) - delta = 0.001 + all_loss = self.lmbda * regularization_loss + similarity_loss - if len(self.identity_map.shape) == 4: - dx = torch.Tensor([[[[delta]], [[0.0]]]]).to(config.device) - dy = torch.Tensor([[[[0.0]], [[delta]]]]).to(config.device) - direction_vectors = (dx, dy) + negative_jacobian_voxels = flips(phi_AB_vectorfield) - elif len(self.identity_map.shape) == 5: - dx = torch.Tensor([[[[[delta]]], [[[0.0]]], [[[0.0]]]]]).to(config.device) - dy = torch.Tensor([[[[[0.0]]], [[[delta]]], [[[0.0]]]]]).to(config.device) - dz = torch.Tensor([[[[0.0]]], [[[0.0]]], [[[delta]]]]).to(config.device) - direction_vectors = (dx, dy, dz) - elif len(self.identity_map.shape) == 3: - dx = torch.Tensor([[[delta]]]).to(config.device) - direction_vectors = (dx,) - - for d in direction_vectors: - approximate_Iepsilon_d = self.phi_AB(self.phi_BA(Iepsilon + d)) - inverse_consistency_error_d = Iepsilon + d - approximate_Iepsilon_d - grad_d_icon_error = ( - inverse_consistency_error - inverse_consistency_error_d - ) / delta - direction_losses.append(torch.mean(grad_d_icon_error**2)) - - inverse_consistency_loss = sum(direction_losses) - - all_loss = self.lmbda * inverse_consistency_loss + similarity_loss - - transform_magnitude = torch.mean( - (self.identity_map - self.phi_AB_vectorfield) ** 2 - ) - return ICONLoss( - all_loss, - inverse_consistency_loss, - similarity_loss, - transform_magnitude, - flips(self.phi_BA_vectorfield), - ) + return {"all_loss": all_loss, "regularization_loss": regularization_loss, "similarity_loss": similarity_loss, "phi_AB": phi_AB, "warped_image_A": similarity_AB["warped_image_A"], "negative_jacobian_voxels":negative_jacobian_voxels} - -class BendingEnergy(network_wrappers.RegistrationModule): - def __init__(self, network, similarity, lmbda): - super().__init__() - - self.regis_net = network - self.lmbda = lmbda - self.similarity = similarity - def compute_bending_energy_loss(self, phi_AB_vectorfield): +class BendingEnergy(OneWayRegularizer): + def compute_regularizer(self, phi_AB_vectorfield): # dxdx = [f[x+h, y] + f[x-h, y] - 2 * f[x, y]]/(h**2) # dxdy = [f[x+h, y+h] + f[x-h, y-h] - f[x+h, y-h] - f[x-h, y+h]]/(4*h**2) # BE_2d = |dxdx| + |dydy| + 2 * |dxdy| - # psudo code: BE_2d = [torch.mean(dxdx**2) + torch.mean(dydy**2) + 2 * torch.mean(dxdy**2)]/4.0 + # pseudo code: BE_2d = [torch.mean(dxdx**2) + torch.mean(dydy**2) + 2 * torch.mean(dxdy**2)]/4.0 # BE_3d = |dxdx| + |dydy| + |dzdz| + 2 * |dxdy| + 2 * |dydz| + 2 * |dxdz| if len(self.identity_map.shape) == 3: @@ -404,75 +222,10 @@ def compute_bending_energy_loss(self, phi_AB_vectorfield): return bending_energy - def compute_similarity_measure(self, phi_AB_vectorfield, image_A, image_B): - - if getattr(self.similarity, "isInterpolated", False): - # tag images during warping so that the similarity measure - # can use information about whether a sample is interpolated - # or extrapolated - inbounds_tag = torch.zeros([image_A.shape[0]] + [1] + list(image_A.shape[2:]), device=image_A.device) - if len(self.input_shape) - 2 == 3: - inbounds_tag[:, :, 1:-1, 1:-1, 1:-1] = 1.0 - elif len(self.input_shape) - 2 == 2: - inbounds_tag[:, :, 1:-1, 1:-1] = 1.0 - else: - inbounds_tag[:, :, 1:-1] = 1.0 - else: - inbounds_tag = None - - self.warped_image_A = self.as_function( - torch.cat([image_A, inbounds_tag], axis=1) if inbounds_tag is not None else image_A - )(phi_AB_vectorfield) - - similarity_loss 
= self.similarity( - self.warped_image_A, image_B - ) - return similarity_loss - - def forward(self, image_A, image_B) -> ICONLoss: - assert self.identity_map.shape[2:] == image_A.shape[2:] - assert self.identity_map.shape[2:] == image_B.shape[2:] - - # Tag used elsewhere for optimization. - # Must be set at beginning of forward b/c not preserved by .cuda() etc - self.identity_map.isIdentity = True - - self.phi_AB = self.regis_net(image_A, image_B) - self.phi_AB_vectorfield = self.phi_AB(self.identity_map) - - similarity_loss = 2 * self.compute_similarity_measure( - self.phi_AB_vectorfield, image_A, image_B - ) - - bending_energy_loss = self.compute_bending_energy_loss( - self.phi_AB_vectorfield - ) - all_loss = self.lmbda * bending_energy_loss + similarity_loss - - transform_magnitude = torch.mean( - (self.identity_map - self.phi_AB_vectorfield) ** 2 - ) - return BendingLoss( - all_loss, - bending_energy_loss, - similarity_loss, - transform_magnitude, - flips(self.phi_AB_vectorfield), - ) - - def prepare_for_viz(self, image_A, image_B): - self.phi_AB = self.regis_net(image_A, image_B) - self.phi_AB_vectorfield = self.phi_AB(self.identity_map) - self.phi_BA = self.regis_net(image_B, image_A) - self.phi_BA_vectorfield = self.phi_BA(self.identity_map) - - self.warped_image_A = self.as_function(image_A)(self.phi_AB_vectorfield) - self.warped_image_B = self.as_function(image_B)(self.phi_BA_vectorfield) - -class Diffusion(BendingEnergyNet): - def compute_bending_energy_loss(self, phi_AB_vectorfield): +class Diffusion(OneWayRegularizer): + def compute_regularizer(self, phi_AB_vectorfield): phi_AB_vectorfield = self.identity_map - phi_AB_vectorfield if len(self.identity_map.shape) == 3: bending_energy = torch.mean(( @@ -503,3 +256,66 @@ def compute_bending_energy_loss(self, phi_AB_vectorfield): return bending_energy * self.identity_map.shape[2] **2 +class VelocityFieldDiffusion(Diffusion): + def forward(self, image_A, image_B): + assert self.identity_map.shape[2:] == image_A.shape[2:] + assert self.identity_map.shape[2:] == image_B.shape[2:] + + # Tag used elsewhere for optimization. + # Must be set at beginning of forward b/c not preserved by .cuda() etc + self.identity_map.isIdentity = True + + phi_AB_dict = self.regis_net(image_A, image_B) + phi_AB = phi_AB_dict["phi_AB"] + + + phi_AB_vectorfield = phi_AB(self.identity_map) + + similarity_AB = self.compute_similarity( image_A, image_B, phi_AB_vectorfield) + + similarity_loss = 2 * similarity_AB["similarity_loss"] + + velocity_fields = phi_AB["velocity_fields"] + regularization_loss = 0 + for v in velocity_fields: + regularization_loss+ = compute_regularizer(self, phi_AB_vectorfield) + + all_loss = self.lmbda * regularization_loss + similarity_loss + + negative_jacobian_voxels = flips(phi_AB_vectorfield) + + return {"all_loss": all_loss, "regularization_loss": regularization_loss, "similarity_loss": similarity_loss, "phi_AB": phi_AB, "warped_image_A": similarity_AB["warped_image_A"], "negative_jacobian_voxels":negative_jacobian_voxels} + +class VelocityFieldBendingEnergy(BendingEnergy): + def forward(self, image_A, image_B): + assert self.identity_map.shape[2:] == image_A.shape[2:] + assert self.identity_map.shape[2:] == image_B.shape[2:] + + # Tag used elsewhere for optimization. 
+ # Must be set at beginning of forward b/c not preserved by .cuda() etc + self.identity_map.isIdentity = True + + phi_AB_dict = self.regis_net(image_A, image_B) + phi_AB = phi_AB_dict["phi_AB"] + + + phi_AB_vectorfield = phi_AB(self.identity_map) + + similarity_AB = self.compute_similarity( image_A, image_B, phi_AB_vectorfield) + + similarity_loss = 2 * similarity_AB["similarity_loss"] + + velocity_fields = phi_AB["velocity_fields"] + regularization_loss = 0 + for v in velocity_fields: + regularization_loss+ = compute_regularizer(self, phi_AB_vectorfield) + + all_loss = self.lmbda * regularization_loss + similarity_loss + + negative_jacobian_voxels = flips(phi_AB_vectorfield) + + return {"all_loss": all_loss, "regularization_loss": regularization_loss, "similarity_loss": similarity_loss, "phi_AB": phi_AB, "warped_image_A": similarity_AB["warped_image_A"], "negative_jacobian_voxels":negative_jacobian_voxels} + + + + diff --git a/ICON/icon_registration/network_wrappers.py b/ICON/icon_registration/network_wrappers.py index 1ed8e40c..de9213aa 100644 --- a/ICON/icon_registration/network_wrappers.py +++ b/ICON/icon_registration/network_wrappers.py @@ -2,6 +2,7 @@ import torch import torch.nn.functional as F from torch import nn +from icon_registration.registration_module import RegistrationModule class DisplacementField(RegistrationModule): diff --git a/ICON/test/test_losses.py b/ICON/test/test_losses.py deleted file mode 100644 index 2fed0f38..00000000 --- a/ICON/test/test_losses.py +++ /dev/null @@ -1,127 +0,0 @@ -import torch -from icon_registration.losses import AdaptiveNCC, normalize -import math -import numbers -import torch -from torch import nn - -import unittest - -class TestLosses(unittest.TestCase): - def test_adaptive_ncc(self): - a = torch.rand(1,1,64,64,64) - b = torch.rand(1,1,64,64,64) - - sim = AdaptiveNCC() - l = sim(a, b) - sim_origin = adaptive_ncc() - l_origin = sim_origin(a, b) - - self.assertAlmostEqual(l.item(), l_origin.item(), places=5) - -class GaussianSmoothing(nn.Module): - """ - Apply gaussian smoothing on a - 1d, 2d or 3d tensor. Filtering is performed seperately for each channel - in the input using a depthwise convolution. - Arguments: - channels (int, sequence): Number of channels of the input tensors. Output will - have this number of channels as well. - kernel_size (int, sequence): Size of the gaussian kernel. - sigma (float, sequence): Standard deviation of the gaussian kernel. - dim (int, optional): The number of dimensions of the data. - Default value is 2 (spatial). - """ - - def __init__(self, channels, kernel_size, sigma, dim=2): - super(GaussianSmoothing, self).__init__() - if isinstance(kernel_size, numbers.Number): - kernel_size = [kernel_size] * dim - if isinstance(sigma, numbers.Number): - sigma = [sigma] * dim - - # The gaussian kernel is the product of the - # gaussian function of each dimension. - kernel = 1 - meshgrids = torch.meshgrid( - [torch.arange(size, dtype=torch.float32) for size in kernel_size] - ) - for size, std, mgrid in zip(kernel_size, sigma, meshgrids): - mean = (size - 1) / 2 - kernel *= ( - 1 - / (std * math.sqrt(2 * math.pi)) - * torch.exp(-(((mgrid - mean) / std) ** 2) / 2) - ) - - # Make sure sum of values in gaussian kernel equals 1. 
- kernel = kernel / torch.sum(kernel) - - # Reshape to depthwise convolutional weight - kernel = kernel.view(1, 1, *kernel.size()) - kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) - - self.register_buffer("weight", kernel) - self.groups = channels - - if dim == 1: - self.conv = nn.functional.conv1d - elif dim == 2: - self.conv = nn.functional.conv2d - elif dim == 3: - self.conv = nn.functional.conv3d - else: - raise RuntimeError( - "Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim) - ) - - def forward(self, input): - """ - Apply gaussian filter to input. - Arguments: - input (torch.Tensor): Input to apply gaussian filter on. - Returns: - filtered (torch.Tensor): Filtered output. - """ - return self.conv( - input, - weight=self.weight, - groups=self.groups, - padding=int(self.weight.shape[2] / 2), - ) - -def adaptive_ncc( - level=4, threshold=0.1, gamma=1.5, smoother=GaussianSmoothing(1, 5, 2, 3)): - def _nccBeforeMean(image_A, image_B): - A = normalize(image_A[:, :1]) - B = normalize(image_B) - res = torch.mean(A * B, dim=(1, 2, 3, 4)) - return 1 - res - - def _sim(x, y): - sims = [_nccBeforeMean(x, y)] - for i in range(level): - if i == 0: - sims.append(_nccBeforeMean(smoother(x), smoother(y))) - else: - sims.append( - _nccBeforeMean( - smoother(nn.functional.avg_pool3d(x, 2**i)), - smoother(nn.functional.avg_pool3d(y, 2**i)), - ) - ) - - sim_loss = sims[0] + 0 - lamb_ = 1.0 - for i in range(1, len(sims)): - lamb = torch.clamp( - sims[i].detach() / (threshold / (gamma ** (len(sims) - i))), 0, 1 - ) - sim_loss = lamb * sims[i] + (1 - lamb) * sim_loss - lamb_ *= 1 - lamb - - return torch.mean(sim_loss) - - return _sim - - \ No newline at end of file From e31b8dfe3551ac89ec156e3bf90f04d5908fd400 Mon Sep 17 00:00:00 2001 From: Hastings Greer Date: Tue, 19 Dec 2023 15:24:05 -0500 Subject: [PATCH 05/15] black pass --- ICON/icon_registration/__init__.py | 2 +- ICON/icon_registration/data.py | 169 ++++++++--- ICON/icon_registration/itk_wrapper.py | 18 +- ICON/icon_registration/losses.py | 265 +++++++++++------- ICON/icon_registration/network_wrappers.py | 91 +++--- ICON/icon_registration/networks.py | 2 +- .../pretrained_models/HCP_brain.py | 10 +- .../pretrained_models/lung_ct.py | 38 ++- ICON/icon_registration/registration_module.py | 18 +- ICON/icon_registration/similarity.py | 81 ++++-- ICON/icon_registration/test_utils.py | 2 +- ICON/icon_registration/train.py | 26 +- ICON/test/test_brain_itk.py | 18 +- ICON/test/test_knee_itk.py | 8 +- ICON/test/test_lung_itk.py | 5 +- 15 files changed, 499 insertions(+), 254 deletions(-) diff --git a/ICON/icon_registration/__init__.py b/ICON/icon_registration/__init__.py index b9488792..ee64a6f3 100644 --- a/ICON/icon_registration/__init__.py +++ b/ICON/icon_registration/__init__.py @@ -9,7 +9,7 @@ ssd, SSDOnlyInterpolated, SSD, - NCC + NCC, ) from icon_registration.network_wrappers import ( DownsampleRegistration, diff --git a/ICON/icon_registration/data.py b/ICON/icon_registration/data.py index 8504bf7c..6905232b 100644 --- a/ICON/icon_registration/data.py +++ b/ICON/icon_registration/data.py @@ -107,7 +107,6 @@ def get_dataset_retina( import elasticdeform import hub except: - raise Exception( """the retina dataset requires the dependencies hub and elasticdeform. 
Try pip install hub elasticdeform""" @@ -120,7 +119,6 @@ def get_dataset_retina( if os.path.exists(ds_name): augmented_ds1_tensor, augmented_ds2_tensor = torch.load(ds_name) else: - res = [] for batch in hub.load("hub://activeloop/drive-train").pytorch( num_workers=0, batch_size=4, shuffle=False @@ -257,103 +255,172 @@ def get_knees_dataset(): return brains, medbrains -def get_copdgene_dataset(data_folder, cache_folder="./data_cache", lung_only=True, downscale=2): - ''' + +def get_copdgene_dataset( + data_folder, cache_folder="./data_cache", lung_only=True, downscale=2 +): + """ This function load the preprocessed COPDGene train set. - ''' + """ import os + def process(iA, downscale, clamp=[-1000, 0], isSeg=False): iA = iA[None, None, :, :, :] - #SI flip + # SI flip iA = torch.flip(iA, dims=(2,)) if isSeg: iA = iA.float() iA = torch.nn.functional.max_pool3d(iA, downscale) - iA[iA>0] = 1 + iA[iA > 0] = 1 else: iA = torch.clip(iA, clamp[0], clamp[1]) + clamp[0] - #TODO: For compatibility to the processed dataset(ranges between -1 to 0) used in paper, we subtract -1 here. + # TODO: For compatibility to the processed dataset(ranges between -1 to 0) used in paper, we subtract -1 here. # Should remove -1 later. - iA = iA / torch.max(iA) - 1. + iA = iA / torch.max(iA) - 1.0 iA = torch.nn.functional.avg_pool3d(iA, downscale) return iA cache_name = f"{cache_folder}/lungs_train_{downscale}xdown_scaled" if os.path.exists(cache_name): - imgs = torch.load(cache_name, map_location='cpu') + imgs = torch.load(cache_name, map_location="cpu") if lung_only: try: - masks = torch.load(f"{cache_folder}/lungs_seg_train_{downscale}xdown_scaled", map_location='cpu') + masks = torch.load( + f"{cache_folder}/lungs_seg_train_{downscale}xdown_scaled", + map_location="cpu", + ) except FileNotFoundError: print("Segmentation data not found.") else: import itk import glob + with open(f"{data_folder}/splits/train.txt") as f: pair_paths = f.readlines() imgs = [] masks = [] for name in tqdm.tqdm(list(iter(pair_paths))[:]): - name = name[:-1] # remove newline - - image_insp = torch.tensor(np.asarray(itk.imread(glob.glob(f"{data_folder} /{name}/{name}_INSP_STD*_COPD_img.nii.gz")[0]))) - image_exp= torch.tensor(np.asarray(itk.imread(glob.glob(f"{data_folder} /{name}/{name}_EXP_STD*_COPD_img.nii.gz")[0]))) + name = name[:-1] # remove newline + + image_insp = torch.tensor( + np.asarray( + itk.imread( + glob.glob( + f"{data_folder} /{name}/{name}_INSP_STD*_COPD_img.nii.gz" + )[0] + ) + ) + ) + image_exp = torch.tensor( + np.asarray( + itk.imread( + glob.glob( + f"{data_folder} /{name}/{name}_EXP_STD*_COPD_img.nii.gz" + )[0] + ) + ) + ) imgs.append((process(image_insp), process(image_exp))) - seg_insp = torch.tensor(np.asarray(itk.imread(glob.glob(f"{data_folder} /{name}/{name}_INSP_STD*_COPD_label.nii.gz")[0]))) - seg_exp= torch.tensor(np.asarray(itk.imread(glob.glob(f"{data_folder} /{name}/{name}_EXP_STD*_COPD_label.nii.gz")[0]))) + seg_insp = torch.tensor( + np.asarray( + itk.imread( + glob.glob( + f"{data_folder} /{name}/{name}_INSP_STD*_COPD_label.nii.gz" + )[0] + ) + ) + ) + seg_exp = torch.tensor( + np.asarray( + itk.imread( + glob.glob( + f"{data_folder} /{name}/{name}_EXP_STD*_COPD_label.nii.gz" + )[0] + ) + ) + ) masks.append((process(seg_insp, True), process(seg_exp, True))) torch.save(imgs, f"{cache_folder}/lungs_train_{downscale}xdown_scaled") torch.save(masks, f"{cache_folder}/lungs_seg_train_{downscale}xdown_scaled") - + if lung_only: - imgs = torch.cat([(torch.cat(d, 1)+1)*torch.cat(m, 1) for d,m in zip(imgs, 
masks)], dim=0) + imgs = torch.cat( + [(torch.cat(d, 1) + 1) * torch.cat(m, 1) for d, m in zip(imgs, masks)], + dim=0, + ) else: - imgs = torch.cat([torch.cat(d, 1)+1 for d in imgs], dim=0) + imgs = torch.cat([torch.cat(d, 1) + 1 for d in imgs], dim=0) return torch.utils.data.TensorDataset(imgs) -def get_learn2reg_AbdomenCTCT_dataset(data_folder, cache_folder="./data_cache", clamp=[-1000,0], downscale=1): - ''' + +def get_learn2reg_AbdomenCTCT_dataset( + data_folder, cache_folder="./data_cache", clamp=[-1000, 0], downscale=1 +): + """ This function will return the training dataset of AbdomenCTCT registration task in learn2reg. - ''' + """ # Check whether we have cached the dataset import os - cache_name = f"{cache_folder}/learn2reg_abdomenctct_train_set_clamp{clamp}scale{downscale}" + cache_name = ( + f"{cache_folder}/learn2reg_abdomenctct_train_set_clamp{clamp}scale{downscale}" + ) if os.path.exists(cache_name): imgs = torch.load(cache_name) else: import json import itk import glob - with open(f"{data_folder}/AbdomenCTCT_dataset.json", 'r') as data_info: + + with open(f"{data_folder}/AbdomenCTCT_dataset.json", "r") as data_info: data_info = json.loads(data_info.read()) - train_cases = [c["image"].split('/')[-1].split('.')[0] for c in data_info["training"]] - imgs = [np.asarray(itk.imread(glob.glob(data_folder + "/imagesTr/" + i + ".nii.gz")[0])) for i in train_cases] - + train_cases = [ + c["image"].split("/")[-1].split(".")[0] for c in data_info["training"] + ] + imgs = [ + np.asarray( + itk.imread(glob.glob(data_folder + "/imagesTr/" + i + ".nii.gz")[0]) + ) + for i in train_cases + ] + imgs = torch.Tensor(np.expand_dims(np.array(imgs), axis=1)).float() - imgs = (torch.clamp(imgs, clamp[0], clamp[1]) - clamp[0])/(clamp[1] - clamp[0]) + imgs = (torch.clamp(imgs, clamp[0], clamp[1]) - clamp[0]) / ( + clamp[1] - clamp[0] + ) # Cache the data if not os.path.exists(cache_folder): os.makedirs(cache_folder) torch.save(imgs, cache_name) - + # Scale down the image if downscale > 1: imgs = F.avg_pool3d(imgs, downscale) return torch.utils.data.TensorDataset(imgs) -def get_learn2reg_lungCT_dataset(data_folder, cache_folder="./data_cache", lung_only=True, clamp=[-1000,0], downscale=1): - ''' + +def get_learn2reg_lungCT_dataset( + data_folder, + cache_folder="./data_cache", + lung_only=True, + clamp=[-1000, 0], + downscale=1, +): + """ This function will return the training dataset of LungCT registration task in learn2reg. 
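    Each element of the returned TensorDataset is a (2, D, H, W) tensor holding one
    fixed/moving pair (in that order), clipped to `clamp` and rescaled to [0, 1];
    when `lung_only` is True the images are masked to the lung region first.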
- ''' + """ import os - cache_name = f"{cache_folder}/learn2reg_lung_train_set_lung_only" if lung_only else f"{cache_folder}/learn2reg_lung_train_set" + cache_name = ( + f"{cache_folder}/learn2reg_lung_train_set_lung_only" + if lung_only + else f"{cache_folder}/learn2reg_lung_train_set" + ) cache_name += f"_clamp{clamp}scale{downscale}" if os.path.exists(cache_name): imgs = torch.load(cache_name) @@ -361,25 +428,45 @@ def get_learn2reg_lungCT_dataset(data_folder, cache_folder="./data_cache", lung_ import json import itk import glob - with open(f"{data_folder}/NLST_dataset.json", 'r') as data_info: + + with open(f"{data_folder}/NLST_dataset.json", "r") as data_info: data_info = json.loads(data_info.read()) - train_pairs = [[p['fixed'].split('/')[-1], p['moving'].split('/')[-1]] for p in data_info["training_paired_images"]] + train_pairs = [ + [p["fixed"].split("/")[-1], p["moving"].split("/")[-1]] + for p in data_info["training_paired_images"] + ] imgs = [] for p in train_pairs: - img = np.array([np.asarray(itk.imread(glob.glob(data_folder + "/imagesTr/" + i)[0])) for i in p]) + img = np.array( + [ + np.asarray(itk.imread(glob.glob(data_folder + "/imagesTr/" + i)[0])) + for i in p + ] + ) if lung_only: - mask = np.array([np.asarray(itk.imread(glob.glob(data_folder + "/" + "/masksTr/" + i)[0])) for i in p]) + mask = np.array( + [ + np.asarray( + itk.imread( + glob.glob(data_folder + "/" + "/masksTr/" + i)[0] + ) + ) + for i in p + ] + ) img = img * mask + clamp[0] * (1 - mask) imgs.append(img) - + imgs = torch.Tensor(np.array(imgs)).float() - imgs = (torch.clamp(imgs, clamp[0], clamp[1]) - clamp[0])/(clamp[1] - clamp[0]) + imgs = (torch.clamp(imgs, clamp[0], clamp[1]) - clamp[0]) / ( + clamp[1] - clamp[0] + ) # Cache the data if not os.path.exists(cache_folder): os.makedirs(cache_folder) torch.save(imgs, cache_name) - + # Scale down the image if downscale > 1: imgs = F.avg_pool3d(imgs, downscale) diff --git a/ICON/icon_registration/itk_wrapper.py b/ICON/icon_registration/itk_wrapper.py index 422d91e0..5559cb6b 100644 --- a/ICON/icon_registration/itk_wrapper.py +++ b/ICON/icon_registration/itk_wrapper.py @@ -27,7 +27,6 @@ def finetune_execute(model, image_A, image_B, steps): def register_pair( model, image_A, image_B, finetune_steps=None, return_artifacts=False ) -> "(itk.CompositeTransform, itk.CompositeTransform)": - assert isinstance(image_A, itk.Image) assert isinstance(image_B, itk.Image) @@ -37,8 +36,8 @@ def register_pair( A_npy = np.array(image_A) B_npy = np.array(image_B) - assert(np.max(A_npy) != np.min(A_npy)) - assert(np.max(B_npy) != np.min(B_npy)) + assert np.max(A_npy) != np.min(A_npy) + assert np.max(B_npy) != np.min(B_npy) # turn images into torch Tensors: add feature and batch dimensions (each of length 1) A_trch = torch.Tensor(A_npy).to(config.device)[None, None] B_trch = torch.Tensor(B_npy).to(config.device)[None, None] @@ -81,11 +80,13 @@ def register_pair( else: return itk_transforms + (to_floats(loss),) + def register_pair_with_multimodalities( model, image_A: list, image_B: list, finetune_steps=None, return_artifacts=False ) -> "(itk.CompositeTransform, itk.CompositeTransform)": - - assert len(image_A) == len(image_B), "image_A and image_B should have the same number of modalities." + assert len(image_A) == len( + image_B + ), "image_A and image_B should have the same number of modalities." 
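    # (Usage sketch, not executed here: callers pass one itk.Image per modality, in the
    # same order for both subjects, e.g.
    #     phi_AB, phi_BA = register_pair_with_multimodalities(model, [A_t1, A_t2], [B_t1, B_t2])
    # where A_t1/A_t2/B_t1/B_t2 are hypothetical itk images. The two returned
    # itk.CompositeTransform objects can be handed to itk.resample_image_filter exactly as
    # in register_pair above, where phi_AB resamples image_A onto image_B's grid.)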
# send model to cpu or gpu depending on config- auto detects capability model.to(config.device) @@ -98,8 +99,8 @@ def register_pair_with_multimodalities( A_npy.append(np.array(image_a)) B_npy.append(np.array(image_b)) - assert(np.max(A_npy[-1]) != np.min(A_npy[-1])) - assert(np.max(B_npy[-1]) != np.min(B_npy[-1])) + assert np.max(A_npy[-1]) != np.min(A_npy[-1]) + assert np.max(B_npy[-1]) != np.min(B_npy[-1]) # turn images into torch Tensors: add batch dimensions (each of length 1) A_trch = torch.Tensor(np.array(A_npy)).to(config.device)[None] @@ -146,7 +147,6 @@ def register_pair_with_multimodalities( def create_itk_transform(phi, ident, image_A, image_B) -> "itk.CompositeTransform": - # itk.DeformationFieldTransform expects a displacement field, so we subtract off the identity map. disp = (phi - ident)[0].cpu() @@ -194,7 +194,6 @@ def create_itk_transform(phi, ident, image_A, image_B) -> "itk.CompositeTransfor def resampling_transform(image, shape): - imageType = itk.template(image)[0][itk.template(image)[1]] dummy_image = itk.image_from_array( @@ -225,7 +224,6 @@ def resampling_transform(image, shape): input_shape = image.GetLargestPossibleRegion().GetSize() for i in range(len(shape)): - m_a[i, i] = image.GetSpacing()[i] * (input_shape[i] / shape[i]) m_a = itk.array_from_matrix(image.GetDirection()) @ m_a diff --git a/ICON/icon_registration/losses.py b/ICON/icon_registration/losses.py index 7587524c..f4fbdbb8 100644 --- a/ICON/icon_registration/losses.py +++ b/ICON/icon_registration/losses.py @@ -8,18 +8,23 @@ import registration_module + class Loss(registration_module.RegistrationModule): def __init__(self, network, similarity, lmbda): super().__init__() self.regis_net = network self.lmbda = lmbda self.similarity = similarity + def compute_similarity(self, image_A, image_B, phi_AB_vectorfield): if getattr(self.similarity, "isInterpolated", False): # tag images during warping so that the similarity measure # can use information about whether a sample is interpolated # or extrapolated - inbounds_tag = torch.zeros([image_A.shape[0]] + [1] + list(image_A.shape[2:]), device=image_A.device) + inbounds_tag = torch.zeros( + [image_A.shape[0]] + [1] + list(image_A.shape[2:]), + device=image_A.device, + ) if len(self.input_shape) - 2 == 3: inbounds_tag[:, :, 1:-1, 1:-1, 1:-1] = 1.0 elif len(self.input_shape) - 2 == 2: @@ -30,13 +35,15 @@ def compute_similarity(self, image_A, image_B, phi_AB_vectorfield): inbounds_tag = None warped_image_A = self.as_function( - torch.cat([image_A, inbounds_tag], axis=1) if inbounds_tag is not None else image_A)( + torch.cat([image_A, inbounds_tag], axis=1) + if inbounds_tag is not None + else image_A + )( self.phi_AB_vectorfield, ) - similarity_loss = self.similarity( - warped_image_A, image_B - ) - return {"similarity_loss":similarity_loss, "warped_image_A":warped_image_A} + similarity_loss = self.similarity(warped_image_A, image_B) + return {"similarity_loss": similarity_loss, "warped_image_A": warped_image_A} + class TwoWayRegularizer(Loss): def forward(self, image_A, image_B): @@ -53,25 +60,34 @@ def forward(self, image_A, image_B): phi_AB_vectorfield = phi_AB(self.identity_map) phi_BA_vectorfield = phi_BA(self.identity_map) - similarity_AB = self.compute_similarity( image_A, image_B, phi_AB_vectorfield) - similarity_BA = self.compute_similarity( image_B, image_A, phi_BA_vectorfield) + similarity_AB = self.compute_similarity(image_A, image_B, phi_AB_vectorfield) + similarity_BA = self.compute_similarity(image_B, image_A, phi_BA_vectorfield) - similarity_loss = 
similarity_AB["similarity_loss"] + similarity_BA["similarity_loss"] + similarity_loss = ( + similarity_AB["similarity_loss"] + similarity_BA["similarity_loss"] + ) regularization_loss = compute_regularizer(self, phi_AB, phi_BA) all_loss = self.lmbda * gradient_inverse_consistency_loss + similarity_loss negative_jacobian_voxels = flips(phi_BA_vectorfield) - return {"all_loss": all_loss, "regularization_loss": inverse_consistency_loss, "similarity_loss": similarity_loss, "phi_AB": phi_AB, "phi_BA": phi_BA, "warped_image_A": similarity_AB["warped_image_A"], "warped_image_B": similiarity_BA["warped_image_A"], "negative_jacobian_voxels":negative_jacobian_voxels} + return { + "all_loss": all_loss, + "regularization_loss": inverse_consistency_loss, + "similarity_loss": similarity_loss, + "phi_AB": phi_AB, + "phi_BA": phi_BA, + "warped_image_A": similarity_AB["warped_image_A"], + "warped_image_B": similiarity_BA["warped_image_A"], + "negative_jacobian_voxels": negative_jacobian_voxels, + } class ICON(TwoWayRegularizer): - def compute_regularizer(self, phi_AB, phi_BA): - Iepsilon = ( - self.identity_map - + torch.randn(*self.identity_map.shape).to(image_A.device) + Iepsilon = self.identity_map + torch.randn(*self.identity_map.shape).to( + image_A.device ) approximate_Iepsilon1 = self.phi_AB(self.phi_BA(Iepsilon)) @@ -83,19 +99,17 @@ def compute_regularizer(self, phi_AB, phi_BA): ) + torch.mean((Iepsilon - approximate_Iepsilon2) ** 2) inverse_consistency_loss /= self.input_shape[2] ** 2 - - return inverse_consistency_loss + return inverse_consistency_loss class GradICON(TwoWayRegularizer): def compute_regularizer(self, phi_AB, phi_BA): - Iepsilon = ( - self.identity_map - + torch.randn(*self.identity_map.shape).to(self.identity_map.device) + Iepsilon = self.identity_map + torch.randn(*self.identity_map.shape).to( + self.identity_map.device ) if len(self.input_shape) - 2 == 3: - Iepsilon =Iepsilon [:, :, ::2, ::2, ::2] + Iepsilon = Iepsilon[:, :, ::2, ::2, ::2] elif len(self.input_shape) - 2 == 2: Iepsilon = Iepsilon[:, :, ::2, ::2] @@ -141,6 +155,7 @@ def compute_regularizer(self, phi_AB, phi_BA): return gradient_inverse_consistency_loss + class OneWayRegularizer(Loss): def forward(self, image_A, image_B): assert self.identity_map.shape[2:] == image_A.shape[2:] @@ -154,107 +169,148 @@ def forward(self, image_A, image_B): phi_AB_vectorfield = phi_AB(self.identity_map) - similarity_AB = self.compute_similarity( image_A, image_B, phi_AB_vectorfield) + similarity_AB = self.compute_similarity(image_A, image_B, phi_AB_vectorfield) - similarity_loss = 2 * similarity_AB["similarity_loss"] + similarity_loss = 2 * similarity_AB["similarity_loss"] regularization_loss = compute_regularizer(self, phi_AB_vectorfield) all_loss = self.lmbda * regularization_loss + similarity_loss negative_jacobian_voxels = flips(phi_AB_vectorfield) - return {"all_loss": all_loss, "regularization_loss": regularization_loss, "similarity_loss": similarity_loss, "phi_AB": phi_AB, "warped_image_A": similarity_AB["warped_image_A"], "negative_jacobian_voxels":negative_jacobian_voxels} + return { + "all_loss": all_loss, + "regularization_loss": regularization_loss, + "similarity_loss": similarity_loss, + "phi_AB": phi_AB, + "warped_image_A": similarity_AB["warped_image_A"], + "negative_jacobian_voxels": negative_jacobian_voxels, + } - class BendingEnergy(OneWayRegularizer): def compute_regularizer(self, phi_AB_vectorfield): # dxdx = [f[x+h, y] + f[x-h, y] - 2 * f[x, y]]/(h**2) # dxdy = [f[x+h, y+h] + f[x-h, y-h] - f[x+h, y-h] - f[x-h, 
y+h]]/(4*h**2) # BE_2d = |dxdx| + |dydy| + 2 * |dxdy| - # pseudo code: BE_2d = [torch.mean(dxdx**2) + torch.mean(dydy**2) + 2 * torch.mean(dxdy**2)]/4.0 + # pseudo code: BE_2d = [torch.mean(dxdx**2) + torch.mean(dydy**2) + 2 * torch.mean(dxdy**2)]/4.0 # BE_3d = |dxdx| + |dydy| + |dzdz| + 2 * |dxdy| + 2 * |dydz| + 2 * |dxdz| - + if len(self.identity_map.shape) == 3: - dxdx = (phi_AB_vectorfield[:, :, 2:] - - 2*phi_AB_vectorfield[:, :, 1:-1] - + phi_AB_vectorfield[:, :, :-2]) / self.spacing[0]**2 - bending_energy = torch.mean((dxdx)**2) - + dxdx = ( + phi_AB_vectorfield[:, :, 2:] + - 2 * phi_AB_vectorfield[:, :, 1:-1] + + phi_AB_vectorfield[:, :, :-2] + ) / self.spacing[0] ** 2 + bending_energy = torch.mean((dxdx) ** 2) + elif len(self.identity_map.shape) == 4: - dxdx = (phi_AB_vectorfield[:, :, 2:] - - 2*phi_AB_vectorfield[:, :, 1:-1] - + phi_AB_vectorfield[:, :, :-2]) / self.spacing[0]**2 - dydy = (phi_AB_vectorfield[:, :, :, 2:] - - 2*phi_AB_vectorfield[:, :, :, 1:-1] - + phi_AB_vectorfield[:, :, :, :-2]) / self.spacing[1]**2 - dxdy = (phi_AB_vectorfield[:, :, 2:, 2:] - + phi_AB_vectorfield[:, :, :-2, :-2] + dxdx = ( + phi_AB_vectorfield[:, :, 2:] + - 2 * phi_AB_vectorfield[:, :, 1:-1] + + phi_AB_vectorfield[:, :, :-2] + ) / self.spacing[0] ** 2 + dydy = ( + phi_AB_vectorfield[:, :, :, 2:] + - 2 * phi_AB_vectorfield[:, :, :, 1:-1] + + phi_AB_vectorfield[:, :, :, :-2] + ) / self.spacing[1] ** 2 + dxdy = ( + phi_AB_vectorfield[:, :, 2:, 2:] + + phi_AB_vectorfield[:, :, :-2, :-2] - phi_AB_vectorfield[:, :, 2:, :-2] - - phi_AB_vectorfield[:, :, :-2, 2:]) / (4.0*self.spacing[0]*self.spacing[1]) - bending_energy = (torch.mean(dxdx**2) + torch.mean(dydy**2) + 2*torch.mean(dxdy**2)) / 4.0 + - phi_AB_vectorfield[:, :, :-2, 2:] + ) / (4.0 * self.spacing[0] * self.spacing[1]) + bending_energy = ( + torch.mean(dxdx**2) + + torch.mean(dydy**2) + + 2 * torch.mean(dxdy**2) + ) / 4.0 elif len(self.identity_map.shape) == 5: - dxdx = (phi_AB_vectorfield[:, :, 2:] - - 2*phi_AB_vectorfield[:, :, 1:-1] - + phi_AB_vectorfield[:, :, :-2]) / self.spacing[0]**2 - dydy = (phi_AB_vectorfield[:, :, :, 2:] - - 2*phi_AB_vectorfield[:, :, :, 1:-1] - + phi_AB_vectorfield[:, :, :, :-2]) / self.spacing[1]**2 - dzdz = (phi_AB_vectorfield[:, :, :, :, 2:] - - 2*phi_AB_vectorfield[:, :, :, :, 1:-1] - + phi_AB_vectorfield[:, :, :, :, :-2]) / self.spacing[2]**2 - dxdy = (phi_AB_vectorfield[:, :, 2:, 2:] - + phi_AB_vectorfield[:, :, :-2, :-2] + dxdx = ( + phi_AB_vectorfield[:, :, 2:] + - 2 * phi_AB_vectorfield[:, :, 1:-1] + + phi_AB_vectorfield[:, :, :-2] + ) / self.spacing[0] ** 2 + dydy = ( + phi_AB_vectorfield[:, :, :, 2:] + - 2 * phi_AB_vectorfield[:, :, :, 1:-1] + + phi_AB_vectorfield[:, :, :, :-2] + ) / self.spacing[1] ** 2 + dzdz = ( + phi_AB_vectorfield[:, :, :, :, 2:] + - 2 * phi_AB_vectorfield[:, :, :, :, 1:-1] + + phi_AB_vectorfield[:, :, :, :, :-2] + ) / self.spacing[2] ** 2 + dxdy = ( + phi_AB_vectorfield[:, :, 2:, 2:] + + phi_AB_vectorfield[:, :, :-2, :-2] - phi_AB_vectorfield[:, :, 2:, :-2] - - phi_AB_vectorfield[:, :, :-2, 2:]) / (4.0*self.spacing[0]*self.spacing[1]) - dydz = (phi_AB_vectorfield[:, :, :, 2:, 2:] - + phi_AB_vectorfield[:, :, :, :-2, :-2] + - phi_AB_vectorfield[:, :, :-2, 2:] + ) / (4.0 * self.spacing[0] * self.spacing[1]) + dydz = ( + phi_AB_vectorfield[:, :, :, 2:, 2:] + + phi_AB_vectorfield[:, :, :, :-2, :-2] - phi_AB_vectorfield[:, :, :, 2:, :-2] - - phi_AB_vectorfield[:, :, :, :-2, 2:]) / (4.0*self.spacing[1]*self.spacing[2]) - dxdz = (phi_AB_vectorfield[:, :, 2:, :, 2:] - + 
phi_AB_vectorfield[:, :, :-2, :, :-2] + - phi_AB_vectorfield[:, :, :, :-2, 2:] + ) / (4.0 * self.spacing[1] * self.spacing[2]) + dxdz = ( + phi_AB_vectorfield[:, :, 2:, :, 2:] + + phi_AB_vectorfield[:, :, :-2, :, :-2] - phi_AB_vectorfield[:, :, 2:, :, :-2] - - phi_AB_vectorfield[:, :, :-2, :, 2:]) / (4.0*self.spacing[0]*self.spacing[2]) - - bending_energy = ((dxdx**2).mean() + (dydy**2).mean() + (dzdz**2).mean() - + 2.*(dxdy**2).mean() + 2.*(dydz**2).mean() + 2.*(dxdz**2).mean()) / 9.0 - + - phi_AB_vectorfield[:, :, :-2, :, 2:] + ) / (4.0 * self.spacing[0] * self.spacing[2]) + + bending_energy = ( + (dxdx**2).mean() + + (dydy**2).mean() + + (dzdz**2).mean() + + 2.0 * (dxdy**2).mean() + + 2.0 * (dydz**2).mean() + + 2.0 * (dxdz**2).mean() + ) / 9.0 return bending_energy - class Diffusion(OneWayRegularizer): def compute_regularizer(self, phi_AB_vectorfield): phi_AB_vectorfield = self.identity_map - phi_AB_vectorfield if len(self.identity_map.shape) == 3: - bending_energy = torch.mean(( - - phi_AB_vectorfield[:, :, 1:] - + phi_AB_vectorfield[:, :, 1:-1] - )**2) + bending_energy = torch.mean( + (-phi_AB_vectorfield[:, :, 1:] + phi_AB_vectorfield[:, :, 1:-1]) ** 2 + ) elif len(self.identity_map.shape) == 4: - bending_energy = torch.mean(( - - phi_AB_vectorfield[:, :, 1:] - + phi_AB_vectorfield[:, :, :-1] - )**2) + torch.mean(( - - phi_AB_vectorfield[:, :, :, 1:] - + phi_AB_vectorfield[:, :, :, :-1] - )**2) + bending_energy = torch.mean( + (-phi_AB_vectorfield[:, :, 1:] + phi_AB_vectorfield[:, :, :-1]) ** 2 + ) + torch.mean( + (-phi_AB_vectorfield[:, :, :, 1:] + phi_AB_vectorfield[:, :, :, :-1]) + ** 2 + ) elif len(self.identity_map.shape) == 5: - bending_energy = torch.mean(( - - phi_AB_vectorfield[:, :, 1:] - + phi_AB_vectorfield[:, :, :-1] - )**2) + torch.mean(( - - phi_AB_vectorfield[:, :, :, 1:] - + phi_AB_vectorfield[:, :, :, :-1] - )**2) + torch.mean(( - - phi_AB_vectorfield[:, :, :, :, 1:] - + phi_AB_vectorfield[:, :, :, :, :-1] - )**2) + bending_energy = ( + torch.mean( + (-phi_AB_vectorfield[:, :, 1:] + phi_AB_vectorfield[:, :, :-1]) ** 2 + ) + + torch.mean( + ( + -phi_AB_vectorfield[:, :, :, 1:] + + phi_AB_vectorfield[:, :, :, :-1] + ) + ** 2 + ) + + torch.mean( + ( + -phi_AB_vectorfield[:, :, :, :, 1:] + + phi_AB_vectorfield[:, :, :, :, :-1] + ) + ** 2 + ) + ) + return bending_energy * self.identity_map.shape[2] ** 2 - return bending_energy * self.identity_map.shape[2] **2 class VelocityFieldDiffusion(Diffusion): def forward(self, image_A, image_B): @@ -268,23 +324,30 @@ def forward(self, image_A, image_B): phi_AB_dict = self.regis_net(image_A, image_B) phi_AB = phi_AB_dict["phi_AB"] - phi_AB_vectorfield = phi_AB(self.identity_map) - similarity_AB = self.compute_similarity( image_A, image_B, phi_AB_vectorfield) + similarity_AB = self.compute_similarity(image_A, image_B, phi_AB_vectorfield) - similarity_loss = 2 * similarity_AB["similarity_loss"] + similarity_loss = 2 * similarity_AB["similarity_loss"] velocity_fields = phi_AB["velocity_fields"] regularization_loss = 0 for v in velocity_fields: - regularization_loss+ = compute_regularizer(self, phi_AB_vectorfield) + regularization_loss += compute_regularizer(self, phi_AB_vectorfield) all_loss = self.lmbda * regularization_loss + similarity_loss negative_jacobian_voxels = flips(phi_AB_vectorfield) - return {"all_loss": all_loss, "regularization_loss": regularization_loss, "similarity_loss": similarity_loss, "phi_AB": phi_AB, "warped_image_A": similarity_AB["warped_image_A"], "negative_jacobian_voxels":negative_jacobian_voxels} + return 
{ + "all_loss": all_loss, + "regularization_loss": regularization_loss, + "similarity_loss": similarity_loss, + "phi_AB": phi_AB, + "warped_image_A": similarity_AB["warped_image_A"], + "negative_jacobian_voxels": negative_jacobian_voxels, + } + class VelocityFieldBendingEnergy(BendingEnergy): def forward(self, image_A, image_B): @@ -298,24 +361,26 @@ def forward(self, image_A, image_B): phi_AB_dict = self.regis_net(image_A, image_B) phi_AB = phi_AB_dict["phi_AB"] - phi_AB_vectorfield = phi_AB(self.identity_map) - similarity_AB = self.compute_similarity( image_A, image_B, phi_AB_vectorfield) + similarity_AB = self.compute_similarity(image_A, image_B, phi_AB_vectorfield) - similarity_loss = 2 * similarity_AB["similarity_loss"] + similarity_loss = 2 * similarity_AB["similarity_loss"] velocity_fields = phi_AB["velocity_fields"] regularization_loss = 0 for v in velocity_fields: - regularization_loss+ = compute_regularizer(self, phi_AB_vectorfield) + regularization_loss += compute_regularizer(self, phi_AB_vectorfield) all_loss = self.lmbda * regularization_loss + similarity_loss negative_jacobian_voxels = flips(phi_AB_vectorfield) - return {"all_loss": all_loss, "regularization_loss": regularization_loss, "similarity_loss": similarity_loss, "phi_AB": phi_AB, "warped_image_A": similarity_AB["warped_image_A"], "negative_jacobian_voxels":negative_jacobian_voxels} - - - - + return { + "all_loss": all_loss, + "regularization_loss": regularization_loss, + "similarity_loss": similarity_loss, + "phi_AB": phi_AB, + "warped_image_A": similarity_AB["warped_image_A"], + "negative_jacobian_voxels": negative_jacobian_voxels, + } diff --git a/ICON/icon_registration/network_wrappers.py b/ICON/icon_registration/network_wrappers.py index de9213aa..baa82d07 100644 --- a/ICON/icon_registration/network_wrappers.py +++ b/ICON/icon_registration/network_wrappers.py @@ -25,25 +25,31 @@ def transform(coordinates): return coordinates + displacement_field(coordinates) return {"phi_AB": transform} - + + class VelocityField(RegistrationModule): - def __init__(self, net): - super().__init__() - self.net = net - self.n_steps = 256 + def __init__(self, net): + super().__init__() + self.net = net + self.n_steps = 256 - def forward(self, image_A, image_B): - concatenated_images = torch.cat([image_A, image_B], axis=1) - velocity_field = self.net(concatenated_images) - velocityfield_delta = velocity_field / self.n_steps + def forward(self, image_A, image_B): + concatenated_images = torch.cat([image_A, image_B], axis=1) + velocity_field = self.net(concatenated_images) + velocityfield_delta = velocity_field / self.n_steps + + for _ in range(8): + velocityfield_delta = velocityfield_delta + self.as_function( + velocityfield_delta + )(velocityfield_delta + self.identity_map) - for _ in range(8): - velocityfield_delta = velocityfield_delta + self.as_function( - velocityfield_delta)(velocityfield_delta + self.identity_map) - def transform(coordinate_tensor): - coordinate_tensor = coordinate_tensor + self.as_function(velocityfield_delta)(coordinate_tensor) - return coordinate_tensor - return {"phi_AB": transform, "velocity_fields": [velocity_field]} + def transform(coordinate_tensor): + coordinate_tensor = coordinate_tensor + self.as_function( + velocityfield_delta + )(coordinate_tensor) + return coordinate_tensor + + return {"phi_AB": transform, "velocity_fields": [velocity_field]} def multiply_matrix_vectorfield(matrix, vectorfield): @@ -73,9 +79,15 @@ def transform(tensor_of_coordinates): shape = list(tensor_of_coordinates.shape) 
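# Self-contained sketch of the homogeneous-coordinate trick that transform()
# applies below. The einsum is an illustrative stand-in for
# multiply_matrix_vectorfield (whose body is outside this hunk); in the module
# itself the matrix also carries a homogeneous row, and the extra coordinate is
# dropped again with [:, :-1].
import torch

matrix_phi = torch.eye(3)[None, :2, :]       # [B, 2, 3] affine map (identity here)
coords = torch.rand(1, 2, 4, 4)              # [B, 2, H, W] coordinate field
ones = torch.ones(1, 1, 4, 4)
coords_h = torch.cat([coords, ones], dim=1)  # homogeneous coordinates: [B, 3, H, W]
warped = torch.einsum("bij,bjhw->bihw", matrix_phi, coords_h)
assert torch.allclose(warped, coords)        # the identity matrix leaves coordinates unchanged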
shape[1] = 1 coordinates_homogeneous = torch.cat( - [tensor_of_coordinates, torch.ones(shape, device=tensor_of_coordinates.device)], axis=1 + [ + tensor_of_coordinates, + torch.ones(shape, device=tensor_of_coordinates.device), + ], + axis=1, ) - return multiply_matrix_vectorfield(matrix_phi, coordinates_homogeneous)[:, :-1] + return multiply_matrix_vectorfield(matrix_phi, coordinates_homogeneous)[ + :, :-1 + ] return {"phi_AB": transform} @@ -96,11 +108,14 @@ def __init__(self, netPhi, netPsi): def forward(self, image_A, image_B): phi = self.netPhi(image_A, image_B) - psi = self.netPsi( - self.as_function(image_A)(phi(self.identity_map)), - image_B - )["phi_AB"] - result = {"phi_AB": lambda tensor_of_coordinates: phi["phi_AB"](psi["phi_AB"](tensor_of_coordinates))} + psi = self.netPsi(self.as_function(image_A)(phi(self.identity_map)), image_B)[ + "phi_AB" + ] + result = { + "phi_AB": lambda tensor_of_coordinates: phi["phi_AB"]( + psi["phi_AB"](tensor_of_coordinates) + ) + } regularization_loss = 0 if "regularization_loss" in phi: @@ -119,6 +134,7 @@ def forward(self, image_A, image_B): result["velocity_fields"] = regularization_loss return result + class Downsample(RegistrationModule): """ Perform registration using the wrapped RegistrationModule `net` @@ -142,7 +158,6 @@ def __init__(self, net, dimension): self.downscale_factor = 2 def forward(self, image_A, image_B): - image_A = self.avg_pool(image_A, 2, ceil_mode=True) image_B = self.avg_pool(image_B, 2, ceil_mode=True) result = self.net(image_A, image_B) @@ -155,7 +170,6 @@ def forward(self, image_A, image_B): result[key] = highres_phi return result - class InverseConsistentVelocityField(RegistrationModule): @@ -167,7 +181,9 @@ def __init__(self, net): def forward(self, image_A, image_B): concatenated_images_AB = torch.cat([image_A, image_B], axis=1) concatenated_images_BA = torch.cat([image_B, image_A], axis=1) - velocity_field = self.net(concatenated_images_AB) - self.net(concatenated_images_BA) + velocity_field = self.net(concatenated_images_AB) - self.net( + concatenated_images_BA + ) velocityfield_delta_ab = velocity_field / 2**self.n_steps velocityfield_delta_ba = -velocityfield_delta_ab @@ -193,8 +209,12 @@ def transform_BA(coordinate_tensor): )(coordinate_tensor) return coordinate_tensor + return { + "phi_AB": transform_AB, + "phi_BA": transform_BA, + "velocity_fields": [velocity_field], + } - return {"phi_AB": transform_AB, "phi_BA": transform_BA, "velocity_fields":[velocity_field]} class InverseConsistentAffine(RegistrationModule): """ @@ -212,8 +232,9 @@ def forward(self, image_A, image_B): concatenated_images_BA = torch.cat([image_B, image_A], axis=1) matrix_phi = self.net(concatenated_images_AB) - self.net(concatenated_images_BA) - matrix_phi = matrix_phi.reshape(image_A.shape[0], len(image_A.shape), len(image_A.shape) + 1) - + matrix_phi = matrix_phi.reshape( + image_A.shape[0], len(image_A.shape), len(image_A.shape) + 1 + ) matrix_phi_AB = torch.linalg.matrix_exp(matrix_phi) matrix_phi_BA = torch.linalg.matrix_exp(-matrix_phi) @@ -228,9 +249,9 @@ def transform_AB(tensor_of_coordinates): ], axis=1, ) - return multiply_matrix_vectorfield( - matrix_phi, coordinates_homogeneous - )[:, :-1] + return multiply_matrix_vectorfield(matrix_phi, coordinates_homogeneous)[ + :, :-1 + ] def transform_BA(tensor_of_coordinates): shape = list(tensor_of_coordinates.shape) @@ -242,8 +263,8 @@ def transform_BA(tensor_of_coordinates): ], axis=1, ) - return multiply_matrix_vectorfield( - matrix_phi_BA, coordinates_homogeneous - )[:, :-1] + 
return multiply_matrix_vectorfield(matrix_phi_BA, coordinates_homogeneous)[ + :, :-1 + ] return {"phi_AB": transform_AB, "phi_BA": transform_BA} diff --git a/ICON/icon_registration/networks.py b/ICON/icon_registration/networks.py index d0d1faf9..8cde1ae3 100644 --- a/ICON/icon_registration/networks.py +++ b/ICON/icon_registration/networks.py @@ -529,7 +529,7 @@ def tallerUNet2(dimension=2): def tallUNet2(dimension=2, input_channels=1): return UNet2( 5, - [[input_channels*2, 16, 32, 64, 256, 512], [16, 32, 64, 128, 256]], + [[input_channels * 2, 16, 32, 64, 256, 512], [16, 32, 64, 128, 256]], dimension, ) diff --git a/ICON/icon_registration/pretrained_models/HCP_brain.py b/ICON/icon_registration/pretrained_models/HCP_brain.py index 4d5fcff6..885e77c8 100644 --- a/ICON/icon_registration/pretrained_models/HCP_brain.py +++ b/ICON/icon_registration/pretrained_models/HCP_brain.py @@ -1,15 +1,19 @@ import itk from .lung_ct import init_network + def brain_network_preprocess(image: "itk.Image") -> "itk.Image": - if type(image) == itk.Image[itk.SS, 3] : - cast_filter = itk.CastImageFilter[itk.Image[itk.SS, 3], itk.Image[itk.F, 3]].New() + if type(image) == itk.Image[itk.SS, 3]: + cast_filter = itk.CastImageFilter[ + itk.Image[itk.SS, 3], itk.Image[itk.F, 3] + ].New() cast_filter.SetInput(image) cast_filter.Update() image = cast_filter.GetOutput() _, max_ = itk.image_intensity_min_max(image) - image = itk.shift_scale_image_filter(image, shift=0., scale = .9 / max_) + image = itk.shift_scale_image_filter(image, shift=0.0, scale=0.9 / max_) return image + def brain_registration_model(pretrained=True): return init_network("brain", pretrained=pretrained) diff --git a/ICON/icon_registration/pretrained_models/lung_ct.py b/ICON/icon_registration/pretrained_models/lung_ct.py index 8586daaa..c224b53b 100644 --- a/ICON/icon_registration/pretrained_models/lung_ct.py +++ b/ICON/icon_registration/pretrained_models/lung_ct.py @@ -9,22 +9,26 @@ def make_network(): dimension = 3 inner_net = network_wrappers.FunctionFromVectorField( - networks.tallUNet2(dimension=dimension)) + networks.tallUNet2(dimension=dimension) + ) for _ in range(2): inner_net = network_wrappers.TwoStepRegistration( - network_wrappers.DownsampleRegistration(inner_net, - dimension=dimension), + network_wrappers.DownsampleRegistration(inner_net, dimension=dimension), network_wrappers.FunctionFromVectorField( - networks.tallUNet2(dimension=dimension))) + networks.tallUNet2(dimension=dimension) + ), + ) inner_net = network_wrappers.TwoStepRegistration( inner_net, network_wrappers.FunctionFromVectorField( - networks.tallUNet2(dimension=dimension))) + networks.tallUNet2(dimension=dimension) + ), + ) - net = losses.GradientICONSparse(inner_net, - similarity=losses.LNCC(sigma=5), - lmbda=1.5) + net = losses.GradientICONSparse( + inner_net, similarity=losses.LNCC(sigma=5), lmbda=1.5 + ) return net @@ -46,12 +50,14 @@ def init_network(task, pretrained=True): if pretrained: from os.path import exists + weights_location = f"network_weights/{task}_model" if not exists(f"{weights_location}/{task}_model_weights.trch"): print("Downloading pretrained model") import urllib.request import os + download_path = "https://github.com/uncbiag/ICON/releases/download" download_path = f"{download_path}/pretrained_models_v1.0.0" @@ -73,25 +79,27 @@ def init_network(task, pretrained=True): return net -def lung_network_preprocess(image: "itk.Image", - segmentation: "itk.Image") -> "itk.Image": - +def lung_network_preprocess( + image: "itk.Image", segmentation: "itk.Image" +) 
-> "itk.Image": image = itk.clamp_image_filter(image, Bounds=(-1000, 0)) cast_filter = itk.CastImageFilter[type(image), itk.Image.F3].New() cast_filter.SetInput(image) cast_filter.Update() image = cast_filter.GetOutput() - segmentation_cast_filter = itk.CastImageFilter[type(segmentation), - itk.Image.F3].New() + segmentation_cast_filter = itk.CastImageFilter[ + type(segmentation), itk.Image.F3 + ].New() segmentation_cast_filter.SetInput(segmentation) segmentation_cast_filter.Update() segmentation = segmentation_cast_filter.GetOutput() image = itk.shift_scale_image_filter(image, shift=1000, scale=1 / 1000) - mask_filter = itk.MultiplyImageFilter[itk.Image.F3, itk.Image.F3, - itk.Image.F3].New() + mask_filter = itk.MultiplyImageFilter[ + itk.Image.F3, itk.Image.F3, itk.Image.F3 + ].New() mask_filter.SetInput1(image) mask_filter.SetInput2(segmentation) diff --git a/ICON/icon_registration/registration_module.py b/ICON/icon_registration/registration_module.py index 43168e45..2ec37451 100644 --- a/ICON/icon_registration/registration_module.py +++ b/ICON/icon_registration/registration_module.py @@ -46,7 +46,7 @@ def as_function(self, image): This allows translating the standard notation of registration papers more literally into code. - I \\circ \\Phi , the standard mathematical notation for a warped image, has the type + I \\circ \\Phi , the standard mathematical notation for a warped image, has the type "function from coordinates to intensities" and can be translated to the python code warped_image = lambda coords: self.as_function(I)(phi(coords)) @@ -56,9 +56,7 @@ def as_function(self, image): warped_image_tensor = warped_image(self.identity_map) """ - return lambda coordinates: self.warp( - image, coordinates - self.identity_map - ) + return lambda coordinates: self.warp(image, coordinates - self.identity_map) def assign_identity_map(self, input_shape, parents_identity_map=None): self.input_shape = input_shape @@ -80,22 +78,24 @@ def assign_identity_map(self, input_shape, parents_identity_map=None): child_shape, # None if self.downscale_factor != 1 else self.identity_map, ) + def make_ddf_from_icon_transform(self, transform): - """Compute A deformation field compatible with monai's Warp + """Compute A deformation field compatible with monai's Warp using an ICON transform. The assosciated ICON identity_map is also required """ return transform(self.identity_map) - self.identity_map + def make_ddf_using_icon_module(self, image_A, image_B): - """Compute a deformation field compatible with monai's Warp + """Compute a deformation field compatible with monai's Warp using an ICON RegistrationModule. If the RegistrationModule returns a transform, this function returns the monai version of that transform. If the RegistrationModule returns a loss, this function returns a monai version of the internal transform as well as the loss. 
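# Hedged end-to-end sketch of the monai interop described in this docstring,
# mirroring the 2D test configuration that appears later in this patch series.
# The image size, lmbda, and the assumption that monai.networks.blocks.Warp
# accepts (image, ddf) are illustrative, not guarantees made by this patch.
import torch
import icon_registration as icon
import icon_registration.config as config
from icon_registration import networks
from monai.networks.blocks import Warp

net = icon.GradICON(
    icon.DisplacementField(networks.tallUNet2(dimension=2)),
    icon.SSD(),
    lmbda=1.0,
)
net.assign_identity_map([1, 1, 64, 64])
net.to(config.device)

image_A = torch.rand(1, 1, 64, 64, device=config.device)
image_B = torch.rand(1, 1, 64, 64, device=config.device)

ddf, result = net.make_ddf_using_icon_module(image_A, image_B)  # ddf: [1, 2, 64, 64]
warped_A = Warp()(image_A, ddf)  # warp on the monai side with the ICON displacement field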
""" res = self(image_A, image_B) - field = self.make_ddf_from_icon_transform(res["phi_AB"] - ) + field = self.make_ddf_from_icon_transform(res["phi_AB"]) return field, res + def forward(image_A, image_B): """Register a pair of images: return a python function phi_AB that warps a tensor of coordinates such that @@ -112,5 +112,3 @@ def forward(image_A, image_B): :return: :math:`\Phi^{AB}` """ raise NotImplementedError() - - diff --git a/ICON/icon_registration/similarity.py b/ICON/icon_registration/similarity.py index 60a78153..ef42146f 100644 --- a/ICON/icon_registration/similarity.py +++ b/ICON/icon_registration/similarity.py @@ -1,4 +1,3 @@ - def normalize(image): dimension = len(image.shape) - 2 if dimension == 2: @@ -14,17 +13,21 @@ class SimilarityBase: def __init__(self, isInterpolated=False): self.isInterpolated = isInterpolated + class NCC(SimilarityBase): def __init__(self): super().__init__(isInterpolated=False) def __call__(self, image_A, image_B): - assert image_A.shape == image_B.shape, "The shape of image_A and image_B sould be the same." + assert ( + image_A.shape == image_B.shape + ), "The shape of image_A and image_B sould be the same." A = normalize(image_A) B = normalize(image_B) res = torch.mean(A * B) return 1 - res + # torch removed this function from torchvision.functional_tensor, so we are vendoring it. def _get_gaussian_kernel1d(kernel_size, sigma): ksize_half = (kernel_size - 1) * 0.5 @@ -33,6 +36,7 @@ def _get_gaussian_kernel1d(kernel_size, sigma): kernel1d = pdf / pdf.sum() return kernel1d + def gaussian_blur(tensor, kernel_size, sigma, padding="same"): kernel1d = _get_gaussian_kernel1d(kernel_size=kernel_size, sigma=sigma).to( tensor.device, dtype=tensor.dtype @@ -41,14 +45,44 @@ def gaussian_blur(tensor, kernel_size, sigma, padding="same"): group = tensor.shape[1] if len(tensor.shape) - 2 == 1: - out = torch.conv1d(out, kernel1d[None, None, :].expand(group,-1,-1), padding="same", groups=group) + out = torch.conv1d( + out, + kernel1d[None, None, :].expand(group, -1, -1), + padding="same", + groups=group, + ) elif len(tensor.shape) - 2 == 2: - out = torch.conv2d(out, kernel1d[None, None, :, None].expand(group,-1,-1,-1), padding="same", groups=group) - out = torch.conv2d(out, kernel1d[None, None, None, :].expand(group,-1,-1,-1), padding="same", groups=group) + out = torch.conv2d( + out, + kernel1d[None, None, :, None].expand(group, -1, -1, -1), + padding="same", + groups=group, + ) + out = torch.conv2d( + out, + kernel1d[None, None, None, :].expand(group, -1, -1, -1), + padding="same", + groups=group, + ) elif len(tensor.shape) - 2 == 3: - out = torch.conv3d(out, kernel1d[None, None, :, None, None].expand(group,-1,-1,-1,-1), padding="same", groups=group) - out = torch.conv3d(out, kernel1d[None, None, None, :, None].expand(group,-1,-1,-1,-1), padding="same", groups=group) - out = torch.conv3d(out, kernel1d[None, None, None, None, :].expand(group,-1,-1,-1,-1), padding="same", groups=group) + out = torch.conv3d( + out, + kernel1d[None, None, :, None, None].expand(group, -1, -1, -1, -1), + padding="same", + groups=group, + ) + out = torch.conv3d( + out, + kernel1d[None, None, None, :, None].expand(group, -1, -1, -1, -1), + padding="same", + groups=group, + ) + out = torch.conv3d( + out, + kernel1d[None, None, None, None, :].expand(group, -1, -1, -1, -1), + padding="same", + groups=group, + ) return out @@ -85,7 +119,6 @@ def blur(self, tensor): return gaussian_blur(tensor, self.sigma * 4 + 1, self.sigma) def __call__(self, image_A, image_B): - I = image_A[:, :-1] J = 
image_B @@ -125,7 +158,9 @@ def blur(self, tensor): return gaussian_blur(tensor, self.sigma * 4 + 1, self.sigma) def __call__(self, image_A, image_B): - assert image_A.shape == image_B.shape, "The shape of image_A and image_B sould be the same." + assert ( + image_A.shape == image_B.shape + ), "The shape of image_A and image_B sould be the same." return torch.mean((self.blur(image_A) - self.blur(image_B)) ** 2) @@ -141,7 +176,10 @@ def blur(self, tensor): return gaussian_blur(tensor, self.sigma * 2 + 1, self.sigma) def __call__(self, image_A, image_B): - assert image_A.shape == image_B.shape, "The shape of image_A and image_B sould be the same." + assert ( + image_A.shape == image_B.shape + ), "The shape of image_A and image_B sould be the same." + def _nccBeforeMean(image_A, image_B): A = normalize(image_A) B = normalize(image_B) @@ -173,14 +211,18 @@ def _nccBeforeMean(image_A, image_B): return torch.mean(sim_loss) + class SSD(SimilarityBase): def __init__(self): super().__init__(isInterpolated=False) def __call__(self, image_A, image_B): - assert image_A.shape == image_B.shape, "The shape of image_A and image_B sould be the same." + assert ( + image_A.shape == image_B.shape + ), "The shape of image_A and image_B sould be the same." return torch.mean((image_A - image_B) ** 2) + class SSDOnlyInterpolated(SimilarityBase): def __init__(self): super().__init__(isInterpolated=True) @@ -195,10 +237,14 @@ def __call__(self, image_A, image_B): inbounds_mask = image_A[:, -1:] image_A = image_A[:, :-1] - assert image_A.shape == image_B.shape, "The shape of image_A and image_B sould be the same." + assert ( + image_A.shape == image_B.shape + ), "The shape of image_A and image_B sould be the same." inbounds_squared_distance = inbounds_mask * (image_A - image_B) ** 2 - sum_squared_distance = torch.sum(inbounds_squared_distance, dimensions_to_sum_over) + sum_squared_distance = torch.sum( + inbounds_squared_distance, dimensions_to_sum_over + ) divisor = torch.sum(inbounds_mask, dimensions_to_sum_over) ssds = sum_squared_distance / divisor return torch.mean(ssds) @@ -212,7 +258,7 @@ def flips(phi, in_percentage=False): dV = torch.sum(torch.cross(a, b, 1) * c, axis=1, keepdims=True) if in_percentage: - return torch.mean((dV < 0).float()) * 100. + return torch.mean((dV < 0).float()) * 100.0 else: return torch.sum(dV < 0) / phi.shape[0] elif len(phi.size()) == 4: @@ -220,15 +266,14 @@ def flips(phi, in_percentage=False): dv = (phi[:, :, :-1, 1:] - phi[:, :, :-1, :-1]).detach() dA = du[:, 0] * dv[:, 1] - du[:, 1] * dv[:, 0] if in_percentage: - return torch.mean((dA < 0).float()) * 100. + return torch.mean((dA < 0).float()) * 100.0 else: return torch.sum(dA < 0) / phi.shape[0] elif len(phi.size()) == 3: du = (phi[:, :, 1:] - phi[:, :, :-1]).detach() if in_percentage: - return torch.mean((du < 0).float()) * 100. 
+ return torch.mean((du < 0).float()) * 100.0 else: return torch.sum(du < 0) / phi.shape[0] else: raise ValueError() - diff --git a/ICON/icon_registration/test_utils.py b/ICON/icon_registration/test_utils.py index 758e7849..16eec5e5 100644 --- a/ICON/icon_registration/test_utils.py +++ b/ICON/icon_registration/test_utils.py @@ -16,7 +16,7 @@ def download_test_data(): "61d3a99d4acac99f429277d7", str(TEST_DATA_DIR), ], - #stdout=sys.stdout, + # stdout=sys.stdout, ) diff --git a/ICON/icon_registration/train.py b/ICON/icon_registration/train.py index 873e12da..bf181c1a 100644 --- a/ICON/icon_registration/train.py +++ b/ICON/icon_registration/train.py @@ -6,6 +6,7 @@ from .losses import ICONLoss, to_floats import icon_registration.config + def write_stats(writer, stats: ICONLoss, ite): for k, v in to_floats(stats)._asdict().items(): writer.add_scalar(k, v, ite) @@ -62,7 +63,12 @@ def train_batchfunction( warped = [] with torch.no_grad(): for i in range(4): - print( unwrapped_net(visualization_moving[i:i + 1], visualization_fixed[i:i + 1])) + print( + unwrapped_net( + visualization_moving[i : i + 1], + visualization_fixed[i : i + 1], + ) + ) warped.append(unwrapped_net.warped_image_A.cpu()) warped = torch.cat(warped) unwrapped_net.train() @@ -77,10 +83,16 @@ def render(im): return im[:4, [0, 0, 0]].detach().cpu() writer.add_images( - "moving_image", render(visualization_moving[:4]), iteration, dataformats="NCHW" + "moving_image", + render(visualization_moving[:4]), + iteration, + dataformats="NCHW", ) writer.add_images( - "fixed_image", render(visualization_fixed[:4]), iteration, dataformats="NCHW" + "fixed_image", + render(visualization_fixed[:4]), + iteration, + dataformats="NCHW", ) writer.add_images( "warped_moving_image", @@ -90,13 +102,15 @@ def render(im): ) writer.add_images( "difference", - render(torch.clip((warped[:4, :1] - visualization_fixed[:4, :1].cpu()) + 0.5, 0, 1)), + render( + torch.clip( + (warped[:4, :1] - visualization_fixed[:4, :1].cpu()) + 0.5, 0, 1 + ) + ), iteration, dataformats="NCHW", ) - - def train_datasets(net, optimizer, d1, d2, epochs=400): """A training function for quick experiments""" diff --git a/ICON/test/test_brain_itk.py b/ICON/test/test_brain_itk.py index 9423aff5..249dc92e 100644 --- a/ICON/test/test_brain_itk.py +++ b/ICON/test/test_brain_itk.py @@ -31,12 +31,12 @@ def test_itk_registration(self): f"{icon_registration.test_utils.TEST_DATA_DIR}/brain_test_data/8_T1w_acpc_dc_restore_brain.nii.gz" ) - image_A_processed = icon_registration.pretrained_models.brain_network_preprocess( - image_A + image_A_processed = ( + icon_registration.pretrained_models.brain_network_preprocess(image_A) ) - image_B_processed = icon_registration.pretrained_models.brain_network_preprocess( - image_B + image_B_processed = ( + icon_registration.pretrained_models.brain_network_preprocess(image_B) ) phi_AB, phi_BA = icon_registration.itk_wrapper.register_pair( @@ -57,7 +57,9 @@ def test_itk_registration(self): ) plt.imshow( - np.array(itk.checker_board_image_filter(warped_image_A, image_B_processed))[40] + np.array(itk.checker_board_image_filter(warped_image_A, image_B_processed))[ + 40 + ] ) plt.colorbar() plt.savefig(footsteps.output_dir + "grid.png") @@ -66,8 +68,10 @@ def test_itk_registration(self): plt.savefig(footsteps.output_dir + "warped.png") plt.clf() - - reference = np.load(icon_registration.test_utils.TEST_DATA_DIR / "brain_test_data/2_and_8_warped_itkfix.npy") + reference = np.load( + icon_registration.test_utils.TEST_DATA_DIR + / 
"brain_test_data/2_and_8_warped_itkfix.npy" + ) np.save( footsteps.output_dir + "warped_brain.npy", diff --git a/ICON/test/test_knee_itk.py b/ICON/test/test_knee_itk.py index 204ce9ab..81b67624 100644 --- a/ICON/test/test_knee_itk.py +++ b/ICON/test/test_knee_itk.py @@ -71,7 +71,9 @@ def test_itk_registration(self): plt.savefig(footsteps.output_dir + "warped.png") plt.clf() - reference = np.load(icon_registration.test_utils.TEST_DATA_DIR / "warped_itkfix.npy") + reference = np.load( + icon_registration.test_utils.TEST_DATA_DIR / "warped_itkfix.npy" + ) np.save( footsteps.output_dir + "warped_knee.npy", @@ -81,6 +83,7 @@ def test_itk_registration(self): self.assertLess( np.mean(np.abs(reference - itk.array_from_image(warped_image_A)[40])), 1e-6 ) + def test_itk_consistency(self): import torch @@ -111,7 +114,7 @@ def test_itk_consistency(self): ) interpolator = itk.LinearInterpolateImageFunction.New(itk_img) - + warped_image_A = itk.resample_image_filter( itk_img, transform=phi, @@ -123,4 +126,3 @@ def test_itk_consistency(self): self.assertTrue( np.all(np.array(warped_image_A) == np.array(warped_image_torch)) ) - diff --git a/ICON/test/test_lung_itk.py b/ICON/test/test_lung_itk.py index f87e0e1e..90a1c850 100644 --- a/ICON/test/test_lung_itk.py +++ b/ICON/test/test_lung_itk.py @@ -11,11 +11,10 @@ class TestItkRegistration(unittest.TestCase): def test_itk_registration(self): - model = icon_registration.pretrained_models.LungCT_registration_model( pretrained=True ) - + icon_registration.test_utils.download_test_data() image_exp = itk.imread( @@ -73,6 +72,7 @@ def test_itk_registration(self): # log some images to show the registration import os + os.environ["FOOTSTEPS_NAME"] = "test" import footsteps @@ -123,4 +123,3 @@ def test_itk_registration(self): dists.append(np.sqrt(np.sum((px - py) ** 2))) print(np.mean(dists)) self.assertLess(np.mean(dists), 2.3) - From d806416436cc9498c1477c06fccbdec4e81bb966 Mon Sep 17 00:00:00 2001 From: Hastings Greer Date: Tue, 19 Dec 2023 15:50:44 -0500 Subject: [PATCH 06/15] start getting tests to pass --- ICON/icon_registration/__init__.py | 33 +++++++++++++------ ICON/icon_registration/losses.py | 18 ++++++++-- .../pretrained_models/OAI_knees.py | 2 +- ICON/icon_registration/train.py | 4 +-- ICON/test/test_2d_registration_train.py | 8 ++--- 5 files changed, 46 insertions(+), 19 deletions(-) diff --git a/ICON/icon_registration/__init__.py b/ICON/icon_registration/__init__.py index ee64a6f3..5ed2873b 100644 --- a/ICON/icon_registration/__init__.py +++ b/ICON/icon_registration/__init__.py @@ -1,21 +1,34 @@ from icon_registration.losses import ( + GradICON, + ICON, + BendingEnergy, + Diffusion, + VelocityFieldBendingEnergy, + VelocityFieldDiffusion, +) + +from icon_registration.similarity import ( LNCC, LNCCOnlyInterpolated, BlurredSSD, - GradientICON, - InverseConsistentNet, - gaussian_blur, - ssd_only_interpolated, - ssd, SSDOnlyInterpolated, SSD, NCC, ) from icon_registration.network_wrappers import ( - DownsampleRegistration, - FunctionFromMatrix, - FunctionFromVectorField, - RegistrationModule, - TwoStepRegistration, + InverseConsistentAffine, + InverseConsistentVelocityField, + Downsample, + TwoStep, + Affine, + VelocityField, + DisplacementField, ) from icon_registration.train import train_batchfunction, train_datasets + + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from icon_registration.registration_module import RegistrationModule diff --git a/ICON/icon_registration/losses.py 
b/ICON/icon_registration/losses.py index f4fbdbb8..c30032b4 100644 --- a/ICON/icon_registration/losses.py +++ b/ICON/icon_registration/losses.py @@ -4,9 +4,23 @@ import torch import torch.nn.functional as F -from icon_registration import config, network_wrappers +from icon_registration import config, network_wrappers, registration_module + +def to_floats(result_dictionary): + """ + Calling forward on the modules in this file returns a rich result dictionary of differentiable torch objects. This function + converts those to scalars suitable for logging in e.g. tensorboard + """ + + out = {} + for key, value in result_dictionary.items(): + if isinstance(value, float) or isinstance(value, int): + out[key] = value + elif isinstance(value, torch.Tensor) and value.shape == (): + out[key] = value.cpu().item() + return out + -import registration_module class Loss(registration_module.RegistrationModule): diff --git a/ICON/icon_registration/pretrained_models/OAI_knees.py b/ICON/icon_registration/pretrained_models/OAI_knees.py index 83cbab3c..4fdff8dd 100644 --- a/ICON/icon_registration/pretrained_models/OAI_knees.py +++ b/ICON/icon_registration/pretrained_models/OAI_knees.py @@ -7,7 +7,7 @@ from .. import networks from .lung_ct import init_network -from ..losses import SSD +from ..similarity import SSD def OAI_knees_registration_model(pretrained=True): diff --git a/ICON/icon_registration/train.py b/ICON/icon_registration/train.py index bf181c1a..93f7b9c6 100644 --- a/ICON/icon_registration/train.py +++ b/ICON/icon_registration/train.py @@ -3,11 +3,11 @@ import torch import tqdm -from .losses import ICONLoss, to_floats +from .losses import to_floats import icon_registration.config -def write_stats(writer, stats: ICONLoss, ite): +def write_stats(writer, stats, ite): for k, v in to_floats(stats)._asdict().items(): writer.add_scalar(k, v, ite) diff --git a/ICON/test/test_2d_registration_train.py b/ICON/test/test_2d_registration_train.py index e1583afc..21a719c1 100644 --- a/ICON/test/test_2d_registration_train.py +++ b/ICON/test/test_2d_registration_train.py @@ -31,8 +31,8 @@ def test_2d_registration_train(self): lmbda = 2048 print("ICON training") - net = icon_registration.InverseConsistentNet( - icon_registration.FunctionFromVectorField(networks.tallUNet2(dimension=2)), + net = icon_registration.ICON( + icon_registration.DisplacementField(networks.tallUNet2(dimension=2)), # Our image similarity metric. The last channel of x and y is whether the value is interpolated or extrapolated, # which is used by some metrics but not this one SSD(), @@ -85,8 +85,8 @@ def test_2d_registration_train_GradICON(self): lmbda = 1.0 print("GradientICON training") - net = icon_registration.GradientICON( - icon_registration.FunctionFromVectorField(networks.tallUNet2(dimension=2)), + net = icon_registration.GradICON( + icon_registration.DisplacementField(networks.tallUNet2(dimension=2)), # Our image similarity metric. 
The last channel of x and y is whether the value is interpolated or extrapolated, # which is used by some metrics but not this one SSD(), From 2408ebd53a24fad1a630aa1267ae093fcfb3fb40 Mon Sep 17 00:00:00 2001 From: Hastings Greer Date: Tue, 19 Dec 2023 15:55:31 -0500 Subject: [PATCH 07/15] some global find replaces --- ICON/icon_registration/losses.py | 5 ++--- .../pretrained_models/OAI_knees.py | 10 +++++----- .../icon_registration/pretrained_models/lung_ct.py | 14 ++++---------- ICON/test/test_2d_registration_train.py | 2 +- 4 files changed, 12 insertions(+), 19 deletions(-) diff --git a/ICON/icon_registration/losses.py b/ICON/icon_registration/losses.py index c30032b4..9b01ed5e 100644 --- a/ICON/icon_registration/losses.py +++ b/ICON/icon_registration/losses.py @@ -6,6 +6,7 @@ from icon_registration import config, network_wrappers, registration_module + def to_floats(result_dictionary): """ Calling forward on the modules in this file returns a rich result dictionary of differentiable torch objects. This function @@ -17,10 +18,8 @@ def to_floats(result_dictionary): if isinstance(value, float) or isinstance(value, int): out[key] = value elif isinstance(value, torch.Tensor) and value.shape == (): - out[key] = value.cpu().item() + out[key] = value.cpu().item() return out - - class Loss(registration_module.RegistrationModule): diff --git a/ICON/icon_registration/pretrained_models/OAI_knees.py b/ICON/icon_registration/pretrained_models/OAI_knees.py index 4fdff8dd..94fb224f 100644 --- a/ICON/icon_registration/pretrained_models/OAI_knees.py +++ b/ICON/icon_registration/pretrained_models/OAI_knees.py @@ -13,22 +13,22 @@ def OAI_knees_registration_model(pretrained=True): # The definition of our final 4 step registration network. - phi = icon_registration.FunctionFromVectorField( + phi = icon_registration.DisplacementField( networks.tallUNet(unet=networks.UNet2ChunkyMiddle, dimension=3) ) - psi = icon_registration.FunctionFromVectorField(networks.tallUNet2(dimension=3)) + psi = icon_registration.DisplacementField(networks.tallUNet2(dimension=3)) pretrained_lowres_net = icon_registration.TwoStepRegistration(phi, psi) hires_net = icon_registration.TwoStepRegistration( icon_registration.DownsampleRegistration(pretrained_lowres_net, dimension=3), - icon_registration.FunctionFromVectorField(networks.tallUNet2(dimension=3)), + icon_registration.DisplacementField(networks.tallUNet2(dimension=3)), ) - fourth_net = icon_registration.InverseConsistentNet( + fourth_net = icon_registration.ICON( icon_registration.TwoStepRegistration( hires_net, - icon_registration.FunctionFromVectorField(networks.tallUNet2(dimension=3)), + icon_registration.DisplacementField(networks.tallUNet2(dimension=3)), ), SSD(), 3600, diff --git a/ICON/icon_registration/pretrained_models/lung_ct.py b/ICON/icon_registration/pretrained_models/lung_ct.py index c224b53b..4abc9930 100644 --- a/ICON/icon_registration/pretrained_models/lung_ct.py +++ b/ICON/icon_registration/pretrained_models/lung_ct.py @@ -8,27 +8,21 @@ def make_network(): dimension = 3 - inner_net = network_wrappers.FunctionFromVectorField( + inner_net = network_wrappers.DisplacementField( networks.tallUNet2(dimension=dimension) ) for _ in range(2): inner_net = network_wrappers.TwoStepRegistration( network_wrappers.DownsampleRegistration(inner_net, dimension=dimension), - network_wrappers.FunctionFromVectorField( - networks.tallUNet2(dimension=dimension) - ), + network_wrappers.DisplacementField(networks.tallUNet2(dimension=dimension)), ) inner_net = 
network_wrappers.TwoStepRegistration( inner_net, - network_wrappers.FunctionFromVectorField( - networks.tallUNet2(dimension=dimension) - ), + network_wrappers.DisplacementField(networks.tallUNet2(dimension=dimension)), ) - net = losses.GradientICONSparse( - inner_net, similarity=losses.LNCC(sigma=5), lmbda=1.5 - ) + net = losses.GradICON(inner_net, similarity=losses.LNCC(sigma=5), lmbda=1.5) return net diff --git a/ICON/test/test_2d_registration_train.py b/ICON/test/test_2d_registration_train.py index 21a719c1..3f2149fd 100644 --- a/ICON/test/test_2d_registration_train.py +++ b/ICON/test/test_2d_registration_train.py @@ -84,7 +84,7 @@ def test_2d_registration_train_GradICON(self): lmbda = 1.0 - print("GradientICON training") + print("GradICON training") net = icon_registration.GradICON( icon_registration.DisplacementField(networks.tallUNet2(dimension=2)), # Our image similarity metric. The last channel of x and y is whether the value is interpolated or extrapolated, From 4a69fbf397e6cff0d59f4bbbd3c4522034d3344e Mon Sep 17 00:00:00 2001 From: Hastings Greer Date: Tue, 2 Jan 2024 15:38:21 -0500 Subject: [PATCH 08/15] work on tests --- ICON/icon_registration/losses.py | 45 +- ICON/icon_registration/networks.py | 481 +----------------- ICON/icon_registration/registration_module.py | 20 +- ICON/icon_registration/similarity.py | 2 + ICON/icon_registration/train.py | 3 +- ICON/test/test_2d_registration_train.py | 5 +- 6 files changed, 72 insertions(+), 484 deletions(-) diff --git a/ICON/icon_registration/losses.py b/ICON/icon_registration/losses.py index 9b01ed5e..6def2aee 100644 --- a/ICON/icon_registration/losses.py +++ b/ICON/icon_registration/losses.py @@ -6,6 +6,33 @@ from icon_registration import config, network_wrappers, registration_module +def flips(phi, in_percentage=False): + if len(phi.size()) == 5: + a = (phi[:, :, 1:, 1:, 1:] - phi[:, :, :-1, 1:, 1:]).detach() + b = (phi[:, :, 1:, 1:, 1:] - phi[:, :, 1:, :-1, 1:]).detach() + c = (phi[:, :, 1:, 1:, 1:] - phi[:, :, 1:, 1:, :-1]).detach() + + dV = torch.sum(torch.cross(a, b, 1) * c, axis=1, keepdims=True) + if in_percentage: + return torch.mean((dV < 0).float()) * 100. + else: + return torch.sum(dV < 0) / phi.shape[0] + elif len(phi.size()) == 4: + du = (phi[:, :, 1:, :-1] - phi[:, :, :-1, :-1]).detach() + dv = (phi[:, :, :-1, 1:] - phi[:, :, :-1, :-1]).detach() + dA = du[:, 0] * dv[:, 1] - du[:, 1] * dv[:, 0] + if in_percentage: + return torch.mean((dA < 0).float()) * 100. + else: + return torch.sum(dA < 0) / phi.shape[0] + elif len(phi.size()) == 3: + du = (phi[:, :, 1:] - phi[:, :, :-1]).detach() + if in_percentage: + return torch.mean((du < 0).float()) * 100. 
+ else: + return torch.sum(du < 0) / phi.shape[0] + else: + raise ValueError() def to_floats(result_dictionary): """ @@ -52,7 +79,7 @@ def compute_similarity(self, image_A, image_B, phi_AB_vectorfield): if inbounds_tag is not None else image_A )( - self.phi_AB_vectorfield, + phi_AB_vectorfield, ) similarity_loss = self.similarity(warped_image_A, image_B) return {"similarity_loss": similarity_loss, "warped_image_A": warped_image_A} @@ -79,20 +106,20 @@ def forward(self, image_A, image_B): similarity_loss = ( similarity_AB["similarity_loss"] + similarity_BA["similarity_loss"] ) - regularization_loss = compute_regularizer(self, phi_AB, phi_BA) + regularization_loss = self.compute_regularizer(phi_AB, phi_BA) - all_loss = self.lmbda * gradient_inverse_consistency_loss + similarity_loss + all_loss = self.lmbda * regularization_loss + similarity_loss negative_jacobian_voxels = flips(phi_BA_vectorfield) return { "all_loss": all_loss, - "regularization_loss": inverse_consistency_loss, + "regularization_loss": regularization_loss, "similarity_loss": similarity_loss, "phi_AB": phi_AB, "phi_BA": phi_BA, "warped_image_A": similarity_AB["warped_image_A"], - "warped_image_B": similiarity_BA["warped_image_A"], + "warped_image_B": similarity_BA["warped_image_A"], "negative_jacobian_voxels": negative_jacobian_voxels, } @@ -100,12 +127,12 @@ def forward(self, image_A, image_B): class ICON(TwoWayRegularizer): def compute_regularizer(self, phi_AB, phi_BA): Iepsilon = self.identity_map + torch.randn(*self.identity_map.shape).to( - image_A.device + self.identity_map.device ) - approximate_Iepsilon1 = self.phi_AB(self.phi_BA(Iepsilon)) + approximate_Iepsilon1 = phi_AB(phi_BA(Iepsilon)) - approximate_Iepsilon2 = self.phi_BA(self.phi_AB(Iepsilon)) + approximate_Iepsilon2 = phi_BA(phi_AB(Iepsilon)) inverse_consistency_loss = torch.mean( (Iepsilon - approximate_Iepsilon1) ** 2 @@ -185,7 +212,7 @@ def forward(self, image_A, image_B): similarity_AB = self.compute_similarity(image_A, image_B, phi_AB_vectorfield) similarity_loss = 2 * similarity_AB["similarity_loss"] - regularization_loss = compute_regularizer(self, phi_AB_vectorfield) + regularization_loss = self.compute_regularizer(phi_AB_vectorfield) all_loss = self.lmbda * regularization_loss + similarity_loss diff --git a/ICON/icon_registration/networks.py b/ICON/icon_registration/networks.py index 8cde1ae3..0700cab7 100644 --- a/ICON/icon_registration/networks.py +++ b/ICON/icon_registration/networks.py @@ -32,8 +32,7 @@ def __init__(self, dimension=2, output_dim=100): self.dense2 = nn.Linear(256, 300) self.dense3 = nn.Linear(300, output_dim) - def forward(self, x, y): - x = torch.cat([x, y], 1) + def forward(self, x): for depth in range(len(self.features) - 1): x = F.relu(x) x = self.convs[depth](x) @@ -44,145 +43,6 @@ def forward(self, x, y): x = self.dense3(x) return x - -class Autoencoder(nn.Module): - def __init__(self, num_layers, channels): - super().__init__() - self.num_layers = num_layers - down_channels = channels[0] - up_channels = channels[1] - self.downConvs = nn.ModuleList([]) - self.upConvs = nn.ModuleList([]) - for depth in range(self.num_layers): - self.downConvs.append( - nn.Conv2d( - down_channels[depth], - down_channels[depth + 1], - kernel_size=3, - padding=1, - stride=2, - ) - ) - self.upConvs.append( - nn.ConvTranspose2d( - up_channels[depth + 1], - up_channels[depth], - kernel_size=4, - padding=1, - stride=2, - ) - ) - self.lastConv = nn.Conv2d(16, 2, kernel_size=3, padding=1) - torch.nn.init.zeros_(self.lastConv.weight) - - def 
forward(self, x, y): - x = torch.cat([x, y], 1) - skips = [] - for depth in range(self.num_layers): - skips.append(x) - x = F.relu(self.downConvs[depth](x)) - for depth in reversed(range(self.num_layers)): - x = F.relu(self.upConvs[depth](x)) - x = x[:, :, : skips[depth].size()[2], : skips[depth].size()[3]] - x = self.lastConv(x) - return x / 10 - - -def tallAE(): - return Autoencoder( - 5, - np.array( - [ - [2, 16, 32, 64, 256, 512], - [16, 32, 64, 128, 256, 512], - ] - ), - ) - - -class Residual(nn.Module): - def __init__(self, features): - super().__init__() - self.bn1 = nn.BatchNorm2d(num_features=features) - self.bn2 = nn.BatchNorm2d(num_features=features) - - self.conv1 = nn.Conv2d(features, features, kernel_size=3, padding=1) - self.conv2 = nn.Conv2d(features, features, kernel_size=3, padding=1) - - def forward(self, x): - y = F.relu(self.bn1(x)) - y = self.conv1(y) - y = F.relu(self.bn2(y)) - y = self.conv2(y) - return y + x - - -class UNet(nn.Module): - def __init__(self, num_layers, channels, dimension): - super().__init__() - - if dimension == 2: - self.BatchNorm = nn.BatchNorm2d - self.Conv = nn.Conv2d - self.ConvTranspose = nn.ConvTranspose2d - else: - self.BatchNorm = nn.BatchNorm3d - self.Conv = nn.Conv3d - self.ConvTranspose = nn.ConvTranspose3d - self.num_layers = num_layers - down_channels = np.array(channels[0]) - up_channels_out = np.array(channels[1]) - up_channels_in = down_channels[1:] + np.concatenate([up_channels_out[1:], [0]]) - self.downConvs = nn.ModuleList([]) - self.upConvs = nn.ModuleList([]) - # self.residues = nn.ModuleList([]) - self.batchNorms = nn.ModuleList( - [ - self.BatchNorm(num_features=up_channels_out[_]) - for _ in range(self.num_layers) - ] - ) - for depth in range(self.num_layers): - self.downConvs.append( - self.Conv( - down_channels[depth], - down_channels[depth + 1], - kernel_size=3, - padding=1, - stride=2, - ) - ) - self.upConvs.append( - self.ConvTranspose( - up_channels_in[depth], - up_channels_out[depth], - kernel_size=4, - padding=1, - stride=2, - ) - ) - # self.residues.append( - # Residual(up_channels_out[depth]) - # ) - self.lastConv = self.Conv(18, dimension, kernel_size=3, padding=1) - torch.nn.init.zeros_(self.lastConv.weight) - - def forward(self, x, y): - x = torch.cat([x, y], 1) - skips = [] - for depth in range(self.num_layers): - skips.append(x) - x = F.relu(self.downConvs[depth](x)) - for depth in reversed(range(self.num_layers)): - x = F.relu(self.upConvs[depth](x)) - x = self.batchNorms[depth](x) - - x = x[:, :, : skips[depth].size()[2], : skips[depth].size()[3]] - x = torch.cat([x, skips[depth]], 1) - x = self.lastConv(x) - return x / 10 - - def pad_or_crop(x, shape, dimension): y = x[:, : shape[1]] if x.size()[1] < shape[1]: @@ -217,7 +77,6 @@ def __init__(self, num_layers, channels, dimension): up_channels_in = down_channels[1:] + np.concatenate([up_channels_out[1:], [0]]) self.downConvs = nn.ModuleList([]) self.upConvs = nn.ModuleList([]) - # self.residues = nn.ModuleList([]) self.batchNorms = nn.ModuleList( [ self.BatchNorm(num_features=up_channels_out[_]) @@ -243,17 +102,13 @@ def __init__(self, num_layers, channels, dimension): stride=2, ) ) - # self.residues.append( - # Residual(up_channels_out[depth]) - # ) self.lastConv = self.Conv( down_channels[0] + up_channels_out[0], dimension, kernel_size=3, padding=1 ) torch.nn.init.zeros_(self.lastConv.weight) torch.nn.init.zeros_(self.lastConv.bias) - def forward(self, x, y): - x = torch.cat([x, y], 1) + def forward(self, x): skips = [] for depth in 
range(self.num_layers): skips.append(x) @@ -270,7 +125,6 @@ def forward(self, x, y): mode=self.interpolate_mode, align_corners=False, ) - # x = self.residues[depth](x) x = self.batchNorms[depth](x) if self.dimension == 2: x = x[:, :, : skips[depth].size()[2], : skips[depth].size()[3]] @@ -284,10 +138,10 @@ def forward(self, x, y): ] x = torch.cat([x, skips[depth]], 1) x = self.lastConv(x) - return x / 10 + return x * 5 -class UNet2ChunkyMiddle(nn.Module): +class UNetDenseMiddle(nn.Module): def __init__(self, num_layers, channels, dimension): super().__init__() self.dimension = dimension @@ -309,7 +163,6 @@ def __init__(self, num_layers, channels, dimension): up_channels_in = down_channels[1:] + np.concatenate([up_channels_out[1:], [0]]) self.downConvs = nn.ModuleList([]) self.upConvs = nn.ModuleList([]) - # self.residues = nn.ModuleList([]) self.batchNorms = nn.ModuleList( [ self.BatchNorm(num_features=up_channels_out[_]) @@ -335,9 +188,6 @@ def __init__(self, num_layers, channels, dimension): stride=2, ) ) - # self.residues.append( - # Residual(up_channels_out[depth]) - # ) self.lastConv = self.Conv(18, dimension, kernel_size=3, padding=1) torch.nn.init.zeros_(self.lastConv.weight) @@ -348,8 +198,7 @@ def __init__(self, num_layers, channels, dimension): ] ) - def forward(self, x, y): - x = torch.cat([x, y], 1) + def forward(self, x): skips = [] for depth in range(self.num_layers): skips.append(x) @@ -370,7 +219,6 @@ def forward(self, x, y): mode=self.interpolate_mode, align_corners=False, ) - # x = self.residues[depth](x) x = self.batchNorms[depth](x) x = x[:, :, : skips[depth].size()[2], : skips[depth].size()[3]] @@ -379,151 +227,6 @@ def forward(self, x, y): return x / 10 -class UNet3(nn.Module): - def __init__(self, num_layers, channels, dimension, normalization): - super().__init__() - - self.dimension = dimension - if dimension == 2: - self.BatchNorm = nn.BatchNorm2d - self.Conv = nn.Conv2d - self.ConvTranspose = nn.ConvTranspose2d - self.avg_pool = F.avg_pool2d - self.interpolate_mode = "bilinear" - else: - self.BatchNorm = nn.BatchNorm3d - self.Conv = nn.Conv3d - self.ConvTranspose = nn.ConvTranspose3d - self.avg_pool = F.avg_pool3d - self.interpolate_mode = "trilinear" - self.num_layers = num_layers - down_channels = np.array(channels[0]) - up_channels_out = np.array(channels[1]) - up_channels_in = down_channels[1:] + np.concatenate([up_channels_out[1:], [0]]) - self.downConvs = nn.ModuleList([]) - self.upConvs = nn.ModuleList([]) - - # More traditional residual structure - # self.down_1x1s = nn.ModuleList([]) - # self.up_1x1s = nn.ModuleList([]) - - # self.residues = nn.ModuleList([]) - self.normalization = normalization - if self.normalization == "batchnorm": - self.batchNorms = nn.ModuleList( - [ - self.BatchNorm(num_features=up_channels_out[_]) - for _ in range(self.num_layers) - ] - ) - if self.normalization == "groupnorm": - self.groupNorms = nn.ModuleList( - [ - nn.GroupNorm( - max(16, up_channels_out[depth]), up_channels_out[depth] - ) - for depth in range(self.num_layers) - ] - ) - for depth in range(self.num_layers): - self.downConvs.append( - self.Conv( - down_channels[depth], - down_channels[depth + 1], - kernel_size=3, - padding=1, - stride=2, - ) - ) - # self.down_1x1s.append( - # self.Conv( - # down_channels[depth + 1], - # down_channels[depth + 1], - # kernel_size=3, - # padding=1, - # stride=1, - # ) - # ) - self.upConvs.append( - self.ConvTranspose( - up_channels_in[depth], - up_channels_out[depth], - kernel_size=4, - padding=1, - stride=2, - ) - ) - # 
self.up_1x1s.append( - # self.Conv( - # up_channels_out[depth], - # up_channels_out[depth], - # kernel_size=3, - # padding=1, - # stride=1, - # ) - # ) - - # self.residues.append( - # Residual(up_channels_out[depth]) - # ) - self.lastConv = self.Conv(18, dimension, kernel_size=3, padding=1) - torch.nn.init.zeros_(self.lastConv.weight) - - def forward(self, x, y): - x = torch.cat([x, y], 1) - skips = [] - for depth in range(self.num_layers): - skips.append(x) - y = self.downConvs[depth](F.leaky_relu(x)) - # y = self.down_1x1s[depth](F.leaky_relu(y)) - x = y + pad_or_crop( - self.avg_pool(x, 2, ceil_mode=True), y.size(), self.dimension - ) - y = F.layer_norm - - for depth in reversed(range(self.num_layers)): - y = self.upConvs[depth](F.leaky_relu(x)) - # y = self.up_1x1s[depth](F.leaky_relu(y)) - x = y + F.interpolate( - pad_or_crop(x, y.size(), self.dimension), - scale_factor=2, - mode=self.interpolate_mode, - align_corners=False, - ) - # x = self.residues[depth](x) - if self.normalization == "batchnorm": - x = self.batchNorms[depth](x) - - if self.normalization == "groupnorm": - x = self.groupNorms[depth](x) - x = x[:, :, : skips[depth].size()[2], : skips[depth].size()[3]] - x = torch.cat([x, skips[depth]], 1) - x = self.lastConv(x) - return x / 10 - - -def tallUNet(unet=UNet, dimension=2): - return unet( - 5, - [[2, 16, 32, 64, 256, 512], [16, 32, 64, 128, 256]], - dimension, - ) - - -def tallishUNet2(dimension=2): - return UNet2( - 6, - [[2, 16, 32, 64, 256, 512, 512], [16, 32, 64, 128, 256, 512]], - dimension, - ) - - -def tallerUNet2(dimension=2): - return UNet2( - 7, - [[2, 16, 32, 64, 256, 512, 512, 512], [16, 32, 64, 128, 256, 512, 512]], - dimension, - ) def tallUNet2(dimension=2, input_channels=1): @@ -534,37 +237,6 @@ def tallUNet2(dimension=2, input_channels=1): ) -def tallUNet3(normalization="batchnorm", dimension=2): - return UNet3( - 5, - [[2, 16, 32, 64, 256, 512], [16, 32, 64, 128, 256]], - dimension, - normalization=normalization, - ) - - -class RegisNet(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(2, 10, kernel_size=5, padding=2) - self.conv2 = nn.Conv2d(12, 10, kernel_size=5, padding=2) - self.conv3 = nn.Conv2d(22, 10, kernel_size=5, padding=2) - self.conv4 = nn.Conv2d(32, 10, kernel_size=5, padding=2) - self.conv5 = nn.Conv2d(42, 10, kernel_size=5, padding=2) - self.conv6 = nn.Conv2d(52, 2, kernel_size=5, padding=2) - - def forward(self, x, y): - x = torch.cat([x, y], 1) - - x = torch.cat([x, F.relu(self.conv1(x))], 1) - x = torch.cat([x, F.relu(self.conv2(x))], 1) - x = torch.cat([x, F.relu(self.conv3(x))], 1) - x = torch.cat([x, F.relu(self.conv4(x))], 1) - x = torch.cat([x, F.relu(self.conv5(x))], 1) - - return self.conv6(x) - - class FCNet1D(nn.Module): def __init__(self, size=28): super().__init__() @@ -574,8 +246,8 @@ def __init__(self, size=28): self.dense3 = nn.Linear(3000, size) torch.nn.init.zeros_(self.dense3.weight) - def forward(self, x, y): - x = torch.reshape(torch.cat([x, y], 1), (-1, 2 * self.size)) + def forward(self, x): + x = torch.reshape(x, (-1, 2 * self.size)) x = F.relu(self.dense1(x)) x = F.relu(self.dense2(x)) x = self.dense3(x) @@ -592,15 +264,14 @@ def __init__(self, size=28): self.dense3 = nn.Linear(3000, size * size * 2) torch.nn.init.zeros_(self.dense3.weight) - def forward(self, x, y): - x = torch.reshape(torch.cat([x, y], 1), (-1, 2 * self.size * self.size)) + def forward(self, x): + x = torch.reshape(x, (-1, 2 * self.size * self.size)) x = F.relu(self.dense1(x)) x = F.relu(self.dense2(x)) x = 
self.dense3(x) x = torch.reshape(x, (-1, 2, self.size, self.size)) return x - class FCNet3D(nn.Module): def __init__(self, shape, bottleneck=128): super().__init__() @@ -614,8 +285,8 @@ def __init__(self, shape, bottleneck=128): torch.nn.init.zeros_(self.dense4.weight) torch.nn.init.zeros_(self.dense4.bias) - def forward(self, x, y): - x = torch.reshape(torch.cat([x, y], 1), (-1, 2 * np.product(self.shape[2:]))) + def forward(self, x): + x = torch.reshape(x, (-1, 2 * np.product(self.shape[2:]))) x = F.relu(self.dense1(x)) x = F.relu(self.dense2(x)) x = F.relu(self.dense3(x)) @@ -624,50 +295,6 @@ def forward(self, x, y): return x -class DenseMatrixNet(nn.Module): - def __init__(self, size=28, dimension=2): - super().__init__() - self.dimension = dimension - self.size = size - self.dense1 = nn.Linear(size * size * 2, 800) - self.dense2 = nn.Linear(800, 300) - self.dense3 = nn.Linear(300, 6 if self.dimension == 2 else 12) - torch.nn.init.zeros_(self.dense3.weight) - - def forward(self, x, y): - x = torch.reshape(torch.cat([x, y], 1), (-1, 2 * self.size * self.size)) - x = F.relu(self.dense1(x)) - x = F.relu(self.dense2(x)) - x = self.dense3(x) - if self.dimension == 3: - x = torch.reshape(x, (-1, 3, 4)) - x = torch.cat( - [ - x, - torch.Tensor([[[0, 0, 0, 1]]]) - .to(x.device) - .expand(x.shape[0], -1, -1), - ], - 1, - ) - x = x + torch.Tensor( - [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]] - ).to(x.device) - elif self.dimension == 2: - x = torch.reshape(x, (-1, 2, 3)) - x = torch.cat( - [ - x, - torch.Tensor([[[0, 0, 1]]]).to(x.device).expand(x.shape[0], -1, -1), - ], - 1, - ) - x = x + torch.Tensor([[1, 0, 0], [0, 1, 0], [0, 0, 0]]).to(x.device) - else: - raise ArgumentError() - return x - - class ConvolutionalMatrixNet(nn.Module): def __init__(self, dimension=2): super().__init__() @@ -696,8 +323,7 @@ def __init__(self, dimension=2): torch.nn.init.zeros_(self.dense3.weight) torch.nn.init.zeros_(self.dense3.bias) - def forward(self, x, y): - x = torch.cat([x, y], 1) + def forward(self, x): for depth in range(len(self.features) - 1): x = F.relu(x) x = self.convs[depth](x) @@ -754,86 +380,3 @@ def forward(self, x, y): return x -class StumpyConvolutionalMatrixNet(nn.Module): - def __init__(self, dimension=2): - super().__init__() - self.dimension = dimension - - if dimension == 2: - self.Conv = nn.Conv2d - self.avg_pool = F.avg_pool2d - else: - self.Conv = nn.Conv3d - self.avg_pool = F.avg_pool3d - - self.features = [2, 16, 32, 64, 128, 256] - self.convs = nn.ModuleList([]) - for depth in range(len(self.features) - 1): - self.convs.append( - self.Conv( - self.features[depth], - self.features[depth + 1], - kernel_size=3, - padding=1, - ) - ) - self.dense2 = nn.Linear(256 * 2 * 3 * 3, 3000) - self.dense3 = nn.Linear(3000, 6 if self.dimension == 2 else 12) - torch.nn.init.zeros_(self.dense3.weight) - torch.nn.init.zeros_(self.dense3.bias) - - def forward(self, x, y): - x = torch.cat([x, y], 1) - for depth in range(len(self.features) - 1): - x = F.relu(x) - x = self.convs[depth](x) - x = self.avg_pool(x, 2, ceil_mode=True) - x = torch.reshape(x, (-1, 256 * 2 * 3 * 3)) - x = F.relu(self.dense2(x)) - x = self.dense3(x) - if self.dimension == 3: - x = torch.reshape(x, (-1, 3, 4)) - x = torch.cat( - [ - x, - torch.Tensor([[[0, 0, 0, 1]]]) - .to(x.device) - .expand(x.shape[0], -1, -1), - ], - 1, - ) - x = x + torch.Tensor( - [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]] - ).to(x.device) - x = torch.matmul( - torch.Tensor( - [[1, 0, 0, 0.5], [0, 1, 0, 0.5], [0, 0, 1, 0.5], [0, 0, 
0, 1]] - ).to(x.device), - x, - ) - x = torch.matmul( - x, - torch.Tensor( - [[1, 0, 0, -0.5], [0, 1, 0, -0.5], [0, 0, 1, -0.5], [0, 0, 0, 1]] - ).to(x.device), - ) - elif self.dimension == 2: - x = torch.reshape(x, (-1, 2, 3)) - x = torch.cat( - [ - x, - torch.Tensor([[[0, 0, 1]]]).to(x.device).expand(x.shape[0], -1, -1), - ], - 1, - ) - x = x + torch.Tensor([[1, 0, 0], [0, 1, 0], [0, 0, 0]]).to(x.device) - x = torch.matmul( - torch.Tensor([[1, 0, 0.5], [0, 1, 0.5], [0, 0, 1]]).to(x.device), x - ) - x = torch.matmul( - x, - torch.Tensor([[1, 0, -0.5], [0, 1, -0.5], [0, 0, 1]]).to(x.device), - ) - else: - raise ArgumentError() - return x diff --git a/ICON/icon_registration/registration_module.py b/ICON/icon_registration/registration_module.py index 2ec37451..dbe71e26 100644 --- a/ICON/icon_registration/registration_module.py +++ b/ICON/icon_registration/registration_module.py @@ -4,6 +4,7 @@ from torch import nn from monai.networks.blocks import Warp +from monai.networks.utils import meshgrid_ij class RegistrationModule(nn.Module): @@ -39,6 +40,12 @@ def __init__(self): self.warp = Warp() self.identity_map = None + def _make_identity_map(self, shape): + mesh_points = [torch.arange(0, dim) for dim in shape[2:]] + grid = torch.stack(meshgrid_ij(*mesh_points), dim=0) # (spatial_dims, ...) + grid = torch.stack([grid], dim=0).float() # (batch, spatial_dims, ...) + return grid + def as_function(self, image): """image is a (potentially vector valued) tensor with shape self.input_shape. Returns a python function that maps a tensor of coordinates [batch x N_dimensions x ...] @@ -56,12 +63,19 @@ def as_function(self, image): warped_image_tensor = warped_image(self.identity_map) """ - return lambda coordinates: self.warp(image, coordinates - self.identity_map) + def partially_applied_warp(coordinates): + coordinates_shape = list(coordinates.shape) + coordinates_shape[0] = image.shape[0] + coordinates = torch.broadcast_to(coordinates, coordinates_shape) + return self.warp(image, coordinates.clone()) + + return partially_applied_warp def assign_identity_map(self, input_shape, parents_identity_map=None): self.input_shape = input_shape - _id = self.warp.get_reference_grid(input_shape) - self.register_buffer("identity_map", _id, persistent=False) + grid = self._make_identity_map(input_shape) + del self.identity_map + self.register_buffer("identity_map", grid, persistent=False) if self.downscale_factor != 1: child_shape = np.concatenate( diff --git a/ICON/icon_registration/similarity.py b/ICON/icon_registration/similarity.py index ef42146f..e49c9c12 100644 --- a/ICON/icon_registration/similarity.py +++ b/ICON/icon_registration/similarity.py @@ -1,3 +1,5 @@ +import torch + def normalize(image): dimension = len(image.shape) - 2 if dimension == 2: diff --git a/ICON/icon_registration/train.py b/ICON/icon_registration/train.py index 93f7b9c6..721abcab 100644 --- a/ICON/icon_registration/train.py +++ b/ICON/icon_registration/train.py @@ -125,10 +125,11 @@ def train_datasets(net, optimizer, d1, d2, epochs=400): loss_object = net(image_A, image_B) - loss_object.all_loss.backward() + loss_object["all_loss"].backward() optimizer.step() loss_history.append(to_floats(loss_object)) + print(to_floats(loss_object)) return loss_history diff --git a/ICON/test/test_2d_registration_train.py b/ICON/test/test_2d_registration_train.py index 3f2149fd..36dc2869 100644 --- a/ICON/test/test_2d_registration_train.py +++ b/ICON/test/test_2d_registration_train.py @@ -42,14 +42,15 @@ def test_2d_registration_train(self): input_shape = 
list(next(iter(d1))[0].size()) input_shape[0] = 1 net.assign_identity_map(input_shape) + print(net.identity_map) net.cuda() optimizer = torch.optim.Adam(net.parameters(), lr=0.001) net.train() - y = icon_registration.train_datasets(net, optimizer, d1, d2, epochs=5) + y = icon_registration.train_datasets(net, optimizer, d1, d2, epochs=50) # Test that image similarity is good enough - self.assertLess(np.mean(np.array(y)[-5:, 1]), 0.1) + self.assertLess(np.mean(np.array([step["similarity_loss"] for step in y])[-5:]), 0.1) # Test that folds are rare enough self.assertLess(np.mean(np.exp(np.array(y)[-5:, 3] - 0.1)), 2) From ec959ef0ccdfc939783f7c9ec0fe0dbc0d353dc5 Mon Sep 17 00:00:00 2001 From: Thomas Greer Date: Fri, 5 Jan 2024 09:46:00 -0500 Subject: [PATCH 09/15] monai requirement --- ICON/requirements.txt | 1 + ICON/setup.cfg | 1 + 2 files changed, 2 insertions(+) diff --git a/ICON/requirements.txt b/ICON/requirements.txt index 5d357d03..39d9d99f 100644 --- a/ICON/requirements.txt +++ b/ICON/requirements.txt @@ -3,6 +3,7 @@ torchvision tensorboard tqdm matplotlib +monai footsteps>=0.1.6 itk==5.3.0 girder_client==3.1.8 diff --git a/ICON/setup.cfg b/ICON/setup.cfg index 83119cf1..297272fd 100644 --- a/ICON/setup.cfg +++ b/ICON/setup.cfg @@ -21,6 +21,7 @@ install_requires = tensorboard tqdm matplotlib + monai footsteps>=0.1.6 itk==5.3.0 girder_client==3.1.8 From a0eaac1502f0c292e2720ac0f3bb677e8b635002 Mon Sep 17 00:00:00 2001 From: Thomas Greer Date: Fri, 5 Jan 2024 12:32:32 -0500 Subject: [PATCH 10/15] plotting in test --- ICON/test/test_2d_registration_train.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ICON/test/test_2d_registration_train.py b/ICON/test/test_2d_registration_train.py index 36dc2869..453e11f6 100644 --- a/ICON/test/test_2d_registration_train.py +++ b/ICON/test/test_2d_registration_train.py @@ -47,7 +47,11 @@ def test_2d_registration_train(self): optimizer = torch.optim.Adam(net.parameters(), lr=0.001) net.train() - y = icon_registration.train_datasets(net, optimizer, d1, d2, epochs=50) + y = icon_registration.train_datasets(net, optimizer, d1, d2, epochs=5) + import matplotlib.pyplot as plt + import footsteps + plt.plot([step["similarity_loss"] for step in y]) + footsteps.plot("2d-icon-similarity") # Test that image similarity is good enough self.assertLess(np.mean(np.array([step["similarity_loss"] for step in y])[-5:]), 0.1) From 15ad6f92088ccf7e2ca58b1caadec07a910ffdef Mon Sep 17 00:00:00 2001 From: Hastings Greer Date: Fri, 5 Jan 2024 12:33:52 -0500 Subject: [PATCH 11/15] viz --- ICON/icon_registration/visualize.py | 33 +++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 ICON/icon_registration/visualize.py diff --git a/ICON/icon_registration/visualize.py b/ICON/icon_registration/visualize.py new file mode 100644 index 00000000..e0460c7d --- /dev/null +++ b/ICON/icon_registration/visualize.py @@ -0,0 +1,33 @@ +import torch +import matplotlib.pyplot as plt + +def show(tensor): + plt.imshow(torchvision.utils.make_grid(tensor[:6], nrow=3)[0].cpu().detach()) + plt.xticks([]) + plt.yticks([]) + +def render(im): + if len(im.shape) == 5: + im = im[:, :, :, :, im.shape[4] // 2] + if torch.min(im) < 0: + im = im - torch.min(im) + if torch.max(im) > 1: + im = im / torch.max(im) + return im[:4, [0, 0, 0]].detach().cpu() + +image_A = next(iter(ds))[0].to(device) +image_B = next(iter(ds))[0].to(device) + +def plot_registration_result(image_A, image_B, registration_result) + + plt.subplot(2, 2, 1) + show(image_A) + plt.subplot(2, 
2, 2) + show(image_B) + plt.subplot(2, 2, 3) + show(net.warped_image_A) + plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach()) + plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach()) + plt.subplot(2, 2, 4) + show(net.warped_image_A - image_B) + plt.tight_layout() From 79cc70eee5302bfb6a87cebc419945a9f3137f1c Mon Sep 17 00:00:00 2001 From: Thomas Greer Date: Fri, 5 Jan 2024 14:05:40 -0500 Subject: [PATCH 12/15] work --- ICON/icon_registration/registration_module.py | 35 ++++++++++++++++++- ICON/icon_registration/visualize.py | 23 ++++++++---- ICON/test/test_2d_registration_train.py | 6 +++- 3 files changed, 55 insertions(+), 9 deletions(-) diff --git a/ICON/icon_registration/registration_module.py b/ICON/icon_registration/registration_module.py index dbe71e26..b568d36f 100644 --- a/ICON/icon_registration/registration_module.py +++ b/ICON/icon_registration/registration_module.py @@ -6,6 +6,39 @@ from monai.networks.blocks import Warp from monai.networks.utils import meshgrid_ij +class CoordinateWarp(Warp): + def forward(self, image: torch.tensor, ddf: torch.tensor): + """ + args: + image: tensor in shape (batch, num_channels, h, w[, d]) + ddf: tensor in the same spatial size as image, in shape (batch, ``spatial_dims``, h, w[, d]) + + returns: + warped_image in the same shape as image (batch, num_channels, h, w[, d]) + """ + spatial_dims = len(image.shape) - 2 + if spatial_dims not in (2, 3): + raise notimplementederror(f"got unsupported spatial_dims={spatial_dims}, currently support 2 or 3.") + ddf_shape = (image.shape[0], spatial_dims) + tuple(image.shape[2:]) + if ddf.shape != ddf_shape: + raise valueerror( + f"given input {spatial_dims}-d image shape {image.shape}, the input ddf shape must be {ddf_shape}, " + f"got {ddf.shape} instead." + ) + grid = ddf # assume + grid = grid.permute([0] + list(range(2, 2 + spatial_dims)) + [1]) # (batch, ..., spatial_dims) + + for i, dim in enumerate(grid.shape[1:-1]): + grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1 + index_ordering: list[int] = list(range(spatial_dims - 1, -1, -1)) + grid = grid[..., index_ordering] # z, y, x -> x, y, z + return F.grid_sample( + image, grid, mode=self._interp_mode, padding_mode=f"{self._padding_mode}", align_corners=True + ) + + + + class RegistrationModule(nn.Module): r"""Base class for icon modules that perform registration. 
@@ -37,7 +70,7 @@ class RegistrationModule(nn.Module): def __init__(self): super().__init__() self.downscale_factor = 1 - self.warp = Warp() + self.warp = CoordinateWarp() self.identity_map = None def _make_identity_map(self, shape): diff --git a/ICON/icon_registration/visualize.py b/ICON/icon_registration/visualize.py index e0460c7d..af729faf 100644 --- a/ICON/icon_registration/visualize.py +++ b/ICON/icon_registration/visualize.py @@ -1,5 +1,7 @@ import torch +import torchvision import matplotlib.pyplot as plt +import footsteps def show(tensor): plt.imshow(torchvision.utils.make_grid(tensor[:6], nrow=3)[0].cpu().detach()) @@ -15,19 +17,26 @@ def render(im): im = im / torch.max(im) return im[:4, [0, 0, 0]].detach().cpu() -image_A = next(iter(ds))[0].to(device) -image_B = next(iter(ds))[0].to(device) -def plot_registration_result(image_A, image_B, registration_result) +def plot_registration_result(image_A, image_B, net) : + + res = net(image_A, image_B) + print(res.keys()) + + + plt.subplot(2, 2, 1) show(image_A) plt.subplot(2, 2, 2) show(image_B) plt.subplot(2, 2, 3) - show(net.warped_image_A) - plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach()) - plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach()) + show(res["warped_image_A"]) + + phi_AB_vectorfield = res["phi_AB"](net.identity_map) + plt.contour(torchvision.utils.make_grid(phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach()) + plt.contour(torchvision.utils.make_grid(phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach()) plt.subplot(2, 2, 4) - show(net.warped_image_A - image_B) + show(res["warped_image_A"]- image_B) plt.tight_layout() + footsteps.plot() diff --git a/ICON/test/test_2d_registration_train.py b/ICON/test/test_2d_registration_train.py index 453e11f6..385ec8b6 100644 --- a/ICON/test/test_2d_registration_train.py +++ b/ICON/test/test_2d_registration_train.py @@ -2,12 +2,13 @@ class Test2DRegistrationTrain(unittest.TestCase): - def test_2d_registration_train(self): + def test_2d_registration_train_ICON(self): import icon_registration import icon_registration.data as data import icon_registration.networks as networks from icon_registration import SSD + import icon_registration.visualize import numpy as np import torch @@ -53,6 +54,9 @@ def test_2d_registration_train(self): plt.plot([step["similarity_loss"] for step in y]) footsteps.plot("2d-icon-similarity") + test_A, test_B = [next(iter(d1))[0].cuda() for _ in range(2)] + icon_registration.visualize.plot_registration_result(test_A, test_B, net) + # Test that image similarity is good enough self.assertLess(np.mean(np.array([step["similarity_loss"] for step in y])[-5:]), 0.1) From 05335d4bebe02826dcc1a97430990b52e9c59fe5 Mon Sep 17 00:00:00 2001 From: Thomas Greer Date: Mon, 8 Jan 2024 10:41:00 -0500 Subject: [PATCH 13/15] gradicon_test_almost_passes --- ICON/icon_registration/losses.py | 2 +- ICON/icon_registration/network_wrappers.py | 4 +--- ICON/icon_registration/networks.py | 2 +- ICON/icon_registration/registration_module.py | 5 ----- 4 files changed, 3 insertions(+), 10 deletions(-) diff --git a/ICON/icon_registration/losses.py b/ICON/icon_registration/losses.py index 6def2aee..e61aaa9f 100644 --- a/ICON/icon_registration/losses.py +++ b/ICON/icon_registration/losses.py @@ -161,7 +161,7 @@ def compute_regularizer(self, phi_AB, phi_BA): inverse_consistency_error = Iepsilon - approximate_Iepsilon - delta = 0.001 + delta = 0.5 if len(self.identity_map.shape) == 4: dx = torch.Tensor([[[[delta]], 
[[0.0]]]]).to(self.identity_map.device) diff --git a/ICON/icon_registration/network_wrappers.py b/ICON/icon_registration/network_wrappers.py index baa82d07..14c193e7 100644 --- a/ICON/icon_registration/network_wrappers.py +++ b/ICON/icon_registration/network_wrappers.py @@ -108,9 +108,7 @@ def __init__(self, netPhi, netPsi): def forward(self, image_A, image_B): phi = self.netPhi(image_A, image_B) - psi = self.netPsi(self.as_function(image_A)(phi(self.identity_map)), image_B)[ - "phi_AB" - ] + psi = self.netPsi(self.as_function(image_A)(phi["phi_AB"](self.identity_map)), image_B) result = { "phi_AB": lambda tensor_of_coordinates: phi["phi_AB"]( psi["phi_AB"](tensor_of_coordinates) diff --git a/ICON/icon_registration/networks.py b/ICON/icon_registration/networks.py index 0700cab7..5f476459 100644 --- a/ICON/icon_registration/networks.py +++ b/ICON/icon_registration/networks.py @@ -138,7 +138,7 @@ def forward(self, x): ] x = torch.cat([x, skips[depth]], 1) x = self.lastConv(x) - return x * 5 + return x * 10 class UNetDenseMiddle(nn.Module): diff --git a/ICON/icon_registration/registration_module.py b/ICON/icon_registration/registration_module.py index b568d36f..4baaf4b0 100644 --- a/ICON/icon_registration/registration_module.py +++ b/ICON/icon_registration/registration_module.py @@ -20,11 +20,6 @@ def forward(self, image: torch.tensor, ddf: torch.tensor): if spatial_dims not in (2, 3): raise notimplementederror(f"got unsupported spatial_dims={spatial_dims}, currently support 2 or 3.") ddf_shape = (image.shape[0], spatial_dims) + tuple(image.shape[2:]) - if ddf.shape != ddf_shape: - raise valueerror( - f"given input {spatial_dims}-d image shape {image.shape}, the input ddf shape must be {ddf_shape}, " - f"got {ddf.shape} instead." - ) grid = ddf # assume grid = grid.permute([0] + list(range(2, 2 + spatial_dims)) + [1]) # (batch, ..., spatial_dims) From 43ccf0f1e48c19e4c2cf11edb2f828e78ac1cc26 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 17 Feb 2024 22:07:03 +0000 Subject: [PATCH 14/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ICON/icon_registration/networks.py | 2 -- ICON/icon_registration/registration_module.py | 2 +- ICON/icon_registration/visualize.py | 2 +- ICON/setup.cfg | 1 - 4 files changed, 2 insertions(+), 5 deletions(-) diff --git a/ICON/icon_registration/networks.py b/ICON/icon_registration/networks.py index 5f476459..e1008071 100644 --- a/ICON/icon_registration/networks.py +++ b/ICON/icon_registration/networks.py @@ -378,5 +378,3 @@ def forward(self, x): else: raise ArgumentError() return x - - diff --git a/ICON/icon_registration/registration_module.py b/ICON/icon_registration/registration_module.py index 4baaf4b0..68835798 100644 --- a/ICON/icon_registration/registration_module.py +++ b/ICON/icon_registration/registration_module.py @@ -20,7 +20,7 @@ def forward(self, image: torch.tensor, ddf: torch.tensor): if spatial_dims not in (2, 3): raise notimplementederror(f"got unsupported spatial_dims={spatial_dims}, currently support 2 or 3.") ddf_shape = (image.shape[0], spatial_dims) + tuple(image.shape[2:]) - grid = ddf # assume + grid = ddf # assume grid = grid.permute([0] + list(range(2, 2 + spatial_dims)) + [1]) # (batch, ..., spatial_dims) for i, dim in enumerate(grid.shape[1:-1]): diff --git a/ICON/icon_registration/visualize.py b/ICON/icon_registration/visualize.py index af729faf..fbfa3d12 100644 --- a/ICON/icon_registration/visualize.py +++ 
b/ICON/icon_registration/visualize.py @@ -25,7 +25,7 @@ def plot_registration_result(image_A, image_B, net) : - + plt.subplot(2, 2, 1) show(image_A) plt.subplot(2, 2, 2) diff --git a/ICON/setup.cfg b/ICON/setup.cfg index 297272fd..30bb9b51 100644 --- a/ICON/setup.cfg +++ b/ICON/setup.cfg @@ -28,4 +28,3 @@ install_requires = [options.packages.find] where = src - From 0de06490e7570e11c5a9ecbfddf4c9a5dce12536 Mon Sep 17 00:00:00 2001 From: HastingsGreer Date: Tue, 15 Oct 2024 10:43:47 -0400 Subject: [PATCH 15/15] Update losses.py Signed-off-by: HastingsGreer --- ICON/icon_registration/losses.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ICON/icon_registration/losses.py b/ICON/icon_registration/losses.py index e61aaa9f..f365ac89 100644 --- a/ICON/icon_registration/losses.py +++ b/ICON/icon_registration/losses.py @@ -373,7 +373,7 @@ def forward(self, image_A, image_B): velocity_fields = phi_AB["velocity_fields"] regularization_loss = 0 for v in velocity_fields: - regularization_loss += compute_regularizer(self, phi_AB_vectorfield) + regularization_loss += self.compute_regularizer(self, v + self.identity_map) all_loss = self.lmbda * regularization_loss + similarity_loss @@ -410,7 +410,7 @@ def forward(self, image_A, image_B): velocity_fields = phi_AB["velocity_fields"] regularization_loss = 0 for v in velocity_fields: - regularization_loss += compute_regularizer(self, phi_AB_vectorfield) + regularization_loss += self.compute_regularizer(self, v) all_loss = self.lmbda * regularization_loss + similarity_loss
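
A note on the Warp -> CoordinateWarp change in this series: the new module resamples by absolute coordinate maps (the identity map plus any displacement) rather than MONAI's displacement-field convention, normalizing each spatial axis to [-1, 1] and reversing the channel order before grid_sample. The following is a minimal 2-D-only sketch in plain PyTorch of that resampling step; it is an illustration, not code from the patch, and the function name plus the bilinear/zeros settings (standing in for the module's _interp_mode/_padding_mode) are assumptions.

import torch
import torch.nn.functional as F

def warp_with_coordinate_map(image, coords):
    # image: (B, C, H, W); coords: (B, 2, H, W) holding absolute pixel
    # coordinates, i.e. the identity meshgrid plus any displacement.
    _, _, H, W = image.shape
    grid = coords.permute(0, 2, 3, 1).clone()        # (B, H, W, 2), channels last
    grid[..., 0] = grid[..., 0] * 2 / (H - 1) - 1    # row axis -> [-1, 1]
    grid[..., 1] = grid[..., 1] * 2 / (W - 1) - 1    # col axis -> [-1, 1]
    grid = grid[..., [1, 0]]                         # (row, col) -> (x, y) as grid_sample expects
    return F.grid_sample(image, grid, mode="bilinear",
                         padding_mode="zeros", align_corners=True)

# Sanity check: warping with the identity coordinate map reproduces the image.
img = torch.rand(1, 1, 8, 8)
rows, cols = torch.meshgrid(torch.arange(8.0), torch.arange(8.0), indexing="ij")
identity = torch.stack([rows, cols])[None]
assert torch.allclose(warp_with_coordinate_map(img, identity), img, atol=1e-4)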
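
Likewise, the refactored wrappers return a dictionary whose "phi_AB" entry is a function of coordinates, so two-step registration (as in the network_wrappers.py hunk above) is just closure composition. A toy sketch of that convention, using made-up translation maps in place of networks; the names here are illustrative only.

import torch

def translation(offset):
    # Returns a transform: absolute coordinates -> translated coordinates.
    return lambda coords: coords + offset.view(1, 2, 1, 1)

phi = {"phi_AB": translation(torch.tensor([1.0, 0.0]))}
psi = {"phi_AB": translation(torch.tensor([0.0, 2.0]))}

# TwoStepRegistration-style composition: apply psi to the coordinates first, then phi.
two_step = lambda coords: phi["phi_AB"](psi["phi_AB"](coords))

rows, cols = torch.meshgrid(torch.arange(4.0), torch.arange(4.0), indexing="ij")
identity = torch.stack([rows, cols])[None]
print(two_step(identity)[0, :, 0, 0])   # tensor([1., 2.])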