diff --git a/CerebNet/data_loader/augmentation.py b/CerebNet/data_loader/augmentation.py index bf30cb66d..f5dc85bb7 100644 --- a/CerebNet/data_loader/augmentation.py +++ b/CerebNet/data_loader/augmentation.py @@ -186,10 +186,14 @@ class RandomBiasField: Based on https://github.com/fepegar/torchio - It was implemented in NiftyNet by Carole Sudre and used in - `Sudre et al., 2017, Longitudinal segmentation of age-related - white matter hyperintensities - `_. + Notes + ----- + It was implemented in NiftyNet by Carole Sudre and used in [1]_. + + References + ---------- + .. [1] Sudre et al., 2017, Longitudinal segmentation of age-related white matter hyperintensities + _. """ def __init__( self, @@ -361,9 +365,7 @@ def sample_intensity_stats_from_image( classes_list = np.array(classes_list, dtype="int") else: classes_list = np.arange(labels_list.shape[0]) - assert len(classes_list) == len( - labels_list - ), "labels and classes lists should have the same length" + assert len(classes_list) == len(labels_list), "labels and classes lists should have the same length" # get unique classes unique_classes, unique_indices = np.unique(classes_list, return_index=True) n_classes = len(unique_classes) diff --git a/CerebNet/data_loader/dataset.py b/CerebNet/data_loader/dataset.py index a9e3fcf45..4f865f484 100644 --- a/CerebNet/data_loader/dataset.py +++ b/CerebNet/data_loader/dataset.py @@ -17,7 +17,6 @@ from typing import Literal, TypeVar import h5py -import nibabel as nib import numpy as np import torch from numpy import typing as npt @@ -34,7 +33,7 @@ transform_axial, transform_sagittal, ) -from FastSurferCNN.utils import Plane, logging +from FastSurferCNN.utils import Plane, logging, nibabelImage ROIKeys = Literal["source_shape", "offsets", "target_shape"] LocalizerROI = dict[ROIKeys, tuple[int, ...]] @@ -236,8 +235,8 @@ class SubjectDataset(Dataset): def __init__( self, - img_org: nib.analyze.SpatialImage, - brain_seg: nib.analyze.SpatialImage, + img_org: nibabelImage, + brain_seg: nibabelImage, patch_size: tuple[int, ...], slice_thickness: int, primary_slice: str | None = None, diff --git a/CerebNet/data_loader/loader.py b/CerebNet/data_loader/loader.py index cd120e569..cb20fe184 100644 --- a/CerebNet/data_loader/loader.py +++ b/CerebNet/data_loader/loader.py @@ -23,16 +23,26 @@ logger = logging.get_logger(__name__) -def get_dataloader(cfg, mode): +def get_dataloader(cfg: object, mode: str) -> DataLoader: """ - Creating the dataset and pytorch data loader + Create the dataset and pytorch data loader. - Args: - cfg: - mode: loading data for train, val and test mode + Parameters + ---------- + cfg : object + Configuration object containing data loading parameters. + mode : str + Loading mode - either 'train' or 'val'. - Returns: - the Dataloader + Returns + ------- + DataLoader + PyTorch DataLoader configured based on the mode. + + Raises + ------ + ValueError + If mode is not 'train' or 'val'. """ if mode == "train": diff --git a/CerebNet/datasets/load_data.py b/CerebNet/datasets/load_data.py index d0a26dfb8..fce47bbdc 100644 --- a/CerebNet/datasets/load_data.py +++ b/CerebNet/datasets/load_data.py @@ -141,14 +141,23 @@ def load_test_subject(self, current_subject): def _get_roi_extracted_data(self, img, label, talairach): """ - Finding the bounding volume and returning extracted img and label - according to roi - Args: - img: - label: - - Returns: - img and label resized according to roi and patch size + Finding the bounding volume and returning extracted img and label according to roi. + + Parameters + ---------- + img : np.ndarray + Input image volume + label : np.ndarray + Input label volume + talairach : np.ndarray or None + Talairach coordinates array + + Returns + ------- + tuple[np.ndarray, np.ndarray, np.ndarray or None] + - Resized image according to ROI and patch size + - Resized label according to ROI and patch size + - Resized and normalized talairach coordinates if provided, None otherwise """ roi = utils.bounding_volume(label, self.patch_size) img = utils.map_size(img[roi], self.patch_size) @@ -171,12 +180,17 @@ def _get_roi_extracted_data(self, img, label, talairach): def _load_auxiliary_data(self, aux_subjects_path): """ - Loading auxiliary data create by registration of original images - Args: - subjects_path: list of full path to auxiliary data + Loading auxiliary data create by registration of original images. - Returns: - dictionary with list of warped images and labels + Parameters + ---------- + aux_subjects_path : list + List of full paths to auxiliary data. + + Returns + ------- + dict + Dictionary containing lists of warped images and labels. """ aux_data = {"auxiliary_img": [], "auxiliary_lbl": []} for t1_path, lbl_path in aux_subjects_path: diff --git a/CerebNet/datasets/utils.py b/CerebNet/datasets/utils.py index 526cd4dfb..bc2024ee3 100644 --- a/CerebNet/datasets/utils.py +++ b/CerebNet/datasets/utils.py @@ -15,16 +15,14 @@ # IMPORTS from collections.abc import Sequence -from pathlib import Path -from typing import Literal, TypedDict, TypeVar +from typing import TypeVar import nibabel as nib import numpy as np import torch -from numpy import typing as npt from FastSurferCNN.data_loader.conform import getscale, scalecrop -from FastSurferCNN.utils import logging +from FastSurferCNN.utils import ShapeType, logging, nibabelImage logger = logging.getLogger(__name__) @@ -63,33 +61,6 @@ subseg_labels = {"cereb_subseg": np.array(list(CLASS_NAMES.values()))} AT = TypeVar("AT", np.ndarray, torch.Tensor) -AffineMatrix4x4 = np.ndarray[tuple[Literal[4], Literal[4]], np.dtype[float]] - - -class LTADict(TypedDict): - type: int - nxforms: int - mean: list[float] - sigma: float - lta: AffineMatrix4x4 - src_valid: int - src_filename: str - src_volume: list[int] - src_voxelsize: list[float] - src_xras: list[float] - src_yras: list[float] - src_zras: list[float] - src_cras: list[float] - dst_valid: int - dst_filename: str - dst_volume: list[int] - dst_voxelsize: list[float] - dst_xras: list[float] - dst_yras: list[float] - dst_zras: list[float] - dst_cras: list[float] - src: npt.NDArray[float] - dst: npt.NDArray[float] def define_size(mov_dim, ref_dim): @@ -273,13 +244,13 @@ def rescale_image(img_data): return new_data -def load_reorient(img_filename: str) -> nib.analyze.SpatialImage: +def load_reorient(img_filename: str) -> nibabelImage: img_file = nib.load(img_filename) canonical_img = nib.as_closest_canonical(img_file) return canonical_img -def load_reorient_lia(img_filename: str) -> nib.analyze.SpatialImage: +def load_reorient_lia(img_filename: str) -> nibabelImage: return load_reorient(img_filename).as_reoriented([[1, -1], [0, -1], [2, 1]]) @@ -356,78 +327,10 @@ def apply_warp_field(dform_field, img, interpol_order=3): return deformed_img -def read_lta(file: Path | str) -> LTADict: - """Read the LTA info.""" - import re - from functools import partial - - import numpy as np - parameter_pattern = re.compile("^\\s*([^=]+)\\s*=\\s*([^#]*)\\s*(#.*)") - vol_info_pattern = re.compile("^(.*) volume info$") - shape_pattern = re.compile("^(\\s*\\d+)+$") - matrix_pattern = re.compile("^(-?\\d+\\.\\S+\\s+)+$") - - _Type = TypeVar("_Type", bound=type) - - def _vector(_a: str, dtype: type[_Type] = float, count: int = -1) -> list[_Type]: - return np.fromstring(_a, dtype=dtype, count=count, sep=" ").tolist() - - parameters = { - "type": int, - "nxforms": int, - "mean": partial(_vector, dtype=float, count=3), - "sigma": float, - "subject": str, - "fscale": float, - } - vol_info_par = { - "valid": int, - "filename": str, - "volume": partial(_vector, dtype=int, count=3), - "voxelsize": partial(_vector, dtype=float, count=3), - **{f"{c}ras": partial(_vector, dtype=float) for c in "xyzc"} - } - - with open(file) as f: - lines = f.readlines() - - items = [] - shape_lines = [] - matrix_lines = [] - section = "" - for i, line in enumerate(lines): - if line.strip() == "": - continue - if hits := parameter_pattern.match(line): - name = hits.group(1) - if section and name in vol_info_par: - items.append((f"{section}_{name}", vol_info_par[name](hits.group(2)))) - elif name in parameters: - section = "" - items.append((name, parameters[name](hits.group(2)))) - else: - raise NotImplementedError(f"Unrecognized type string in lta-file " - f"{file}:{i+1}: '{name}'") - elif hits := vol_info_pattern.match(line): - section = hits.group(1) - # not a parameter line - elif shape_pattern.search(line): - shape_lines.append(np.fromstring(line, dtype=int, count=-1, sep=" ")) - elif matrix_pattern.search(line): - matrix_lines.append(np.fromstring(line, dtype=float, count=-1, sep=" ")) - - shape_lines = list(map(tuple, shape_lines)) - lta = dict(items) - if lta["nxforms"] != len(shape_lines): - raise OSError("Inconsistent lta format: nxforms inconsistent with shapes.") - if len(shape_lines) > 1 and np.any(np.not_equal([shape_lines[0]], shape_lines[1:])): - raise OSError(f"Inconsistent lta format: shapes inconsistent {shape_lines}") - lta_matrix = np.asarray(matrix_lines).reshape((-1,) + shape_lines[0].shape) - lta["lta"] = lta_matrix - return lta - - def load_talairach_coordinates(tala_path, img_shape, vox2ras): + """Load talairach coordinates from file.""" + from FastSurferCNN.utils.lta import read_lta + tala_lta = read_lta(tala_path) # create image grid p x, y, z = np.meshgrid( @@ -448,7 +351,9 @@ def load_talairach_coordinates(tala_path, img_shape, vox2ras): return tala_coordinates -def normalize_array(arr): +def normalize_array(arr: np.ndarray[ShapeType, np.dtype[np.number]]) \ + -> np.ndarray[ShapeType, np.dtype[np.floating]]: + """Normalize the data array to [0, 1].""" min = arr.min() max = arr.max() diff --git a/CerebNet/inference.py b/CerebNet/inference.py index 768e87aeb..e871be255 100644 --- a/CerebNet/inference.py +++ b/CerebNet/inference.py @@ -18,7 +18,6 @@ from pathlib import Path from typing import TYPE_CHECKING -import nibabel as nib import numpy as np import pandas as pd import torch @@ -30,7 +29,7 @@ from CerebNet.models.networks import build_model from CerebNet.utils import checkpoint as cp from FastSurferCNN.data_loader.conform import crop_transform -from FastSurferCNN.utils import PLANES, Plane, logging +from FastSurferCNN.utils import PLANES, Plane, logging, nibabelImage from FastSurferCNN.utils.arg_types import ImageSizeOption, OrientationType from FastSurferCNN.utils.common import SubjectDirectory, SubjectList, find_device from FastSurferCNN.utils.mapper import JsonColorLookupTable, Mapper, TSVLookupTable @@ -260,12 +259,16 @@ def _convert(plane: Plane) -> torch.Tensor: def _view_aggregation(self, logits: dict[Plane, torch.Tensor]) -> torch.Tensor: """ Aggregate the view (axial, coronal, sagittal) into one volume and get the - class of the largest probability. (argmax) + class of the largest probability (argmax). - Args: - logits: dictionary of per plane predicted logits (axial, coronal, sagittal) + Parameters + ---------- + logits : dict[Plane, torch.Tensor] + Dictionary of per plane predicted logits (axial, coronal, sagittal) - Returns: + Returns + ------- + torch.Tensor Tensor of classes (of largest aggregated logits) """ aggregated_logits = torch.add( @@ -328,7 +331,7 @@ def _save_cerebnet_seg( self, cerebnet_seg: np.ndarray, filename: str | Path, - orig: nib.analyze.SpatialImage + orig: nibabelImage, ) -> "Future[None]": """ Saving the segmentations asynchronously. @@ -339,7 +342,7 @@ def _save_cerebnet_seg( Segmentation data. filename : Path, str Path and file name to the saved file. - orig : nib.analyze.SpatialImage + orig : nibabelImage File container (with header and affine) used to populate header and affine of the segmentation. diff --git a/CerebNet/utils/lr_scheduler.py b/CerebNet/utils/lr_scheduler.py index e78297f1b..813696505 100644 --- a/CerebNet/utils/lr_scheduler.py +++ b/CerebNet/utils/lr_scheduler.py @@ -41,11 +41,20 @@ def __init__(self, optimizer, *args, T_0=10, Tmult=1, lr_restart=None, **kwargs) and a number, it is reset to initial lr * (lr_restart) ^ i, if lr_restart is a function, the lr gets reset to lr_restart(initial_lr, i). - Args: - ...: same as ReduceLROnPlateau - T_0 (optional): number of epochs until first restart (default: 10) - Tmult (optional): multiplicative factor for future restarts (default: 1) - lr_restart (optinoal): multiplicative factor for learning rate adjustment at restart. + Parameters + ---------- + optimizer : torch.optim.Optimizer + Wrapped optimizer + *args + Arguments passed to ReduceLROnPlateau + T_0 : int, optional + Number of epochs until first restart. Default is 10 + Tmult : int, optional + Multiplicative factor for future restarts. Default is 1 + lr_restart : float or callable, optional + Multiplicative factor for learning rate adjustment at restart + **kwargs + Keyword arguments passed to ReduceLROnPlateau """ # from torch.optim.lr_scheduler._LRSchduler # if last_epoch == -1: @@ -175,15 +184,21 @@ def _get_warmup_factor_at_iter( Return the learning rate warmup factor at a specific iteration. See :paper:`in1k1h` for more details. - Args: - method (str): warmup method; either "constant" or "linear". - iter (int): iteration at which to calculate the warmup factor. - warmup_iters (int): the number of warmup iterations. - warmup_factor (float): the base warmup factor (the meaning changes according - to the method used). - - Returns: - float: the effective warmup factor at the given iteration. + Parameters + ---------- + method : str + Warmup method; either "constant" or "linear" + iter : int + Iteration at which to calculate the warmup factor + warmup_iters : int + The number of warmup iterations + warmup_factor : float + The base warmup factor (the meaning changes according to the method used) + + Returns + ------- + float + The effective warmup factor at the given iteration """ if iter >= warmup_iters: return 1.0 diff --git a/CorpusCallosum/README.md b/CorpusCallosum/README.md new file mode 100644 index 000000000..9e945e35d --- /dev/null +++ b/CorpusCallosum/README.md @@ -0,0 +1,16 @@ +# Corpus Callosum Pipeline + +A deep learning-based pipeline for automated segmentation, analysis, and shape analysis of the corpus callosum in brain MRI scans. +Also segments the fornix, localizes the anterior and posterior commissure (AC and PC) and standardizes the orientation of the brain. + +For detailed documentation, please refer to: +- [Module Overview](../doc/overview/modules/CC.md): Detailed description of the pipeline, workflow, and analysis options. +- [Output Files](../doc/overview/OUTPUT_FILES.md#corpus-callosum-module): List of output files and their descriptions. + +## Quickstart + +```bash +python3 fastsurfer_cc.py --sd /path/to/fastsurfer/output --sid test-case --verbose +``` + +Gives all standard outputs. The corpus callosum morphometry can be found at `stats/callosum.CC.midslice.json` including 100 thickness measurements and the areas of sub-segments. diff --git a/CorpusCallosum/__init__.py b/CorpusCallosum/__init__.py new file mode 100644 index 000000000..63db725af --- /dev/null +++ b/CorpusCallosum/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = [ + "data", + "segmentation", + "transforms", + "utils", +] diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py new file mode 100644 index 000000000..75f2e753d --- /dev/null +++ b/CorpusCallosum/cc_visualization.py @@ -0,0 +1,241 @@ +import argparse +import sys +from pathlib import Path +from typing import Literal + +import numpy as np + +from CorpusCallosum.data.constants import FSAVERAGE_DATA_PATH +from CorpusCallosum.data.fsaverage_cc_template import load_fsaverage_cc_template +from CorpusCallosum.data.read_write import load_fsaverage_data +from CorpusCallosum.shape.contour import CCContour +from CorpusCallosum.shape.mesh import create_CC_mesh_from_contours +from FastSurferCNN.utils.logging import get_logger, setup_logging + +logger = get_logger(__name__) + + +def make_parser() -> argparse.ArgumentParser: + """Create a command line parser for the visualization pipeline.""" + parser = argparse.ArgumentParser(description="Visualize corpus callosum from template files.") + parser.add_argument( + "--template_dir", + type=str, + required=True, + help=( + "Path to a template directory containing per-slice files named " + "thickness_values_.txt, and optionally contour_.txt " + "and thickness_measurement_points_.txt. If contour_.txt " + "and thickness_measurement_points_.txt are not provided, " + "uses fsaverage template." + ), + metavar="TEMPLATE_DIR", + default=None, + ) + parser.add_argument("--output_dir", + type=str, + required=True, + help="Directory for output files. Writes: " + "cc_mesh.html - Interactive 3D mesh visualization (HTML file) " + "midslice_2d.png - 2D midslice visualization of the corpus callosum " + "cc_mesh.vtk - VTK mesh file format " + "cc_mesh.fssurf - FreeSurfer surface file " + "cc_mesh_overlay.curv - FreeSurfer curvature overlay file " + "cc_mesh_snap.png - Screenshot/snapshot of the 3D mesh (requires whippersnappy>=1.3.1)", + metavar="OUTPUT_DIR" + ) + parser.add_argument( + "--resolution", + type=float, + default=1.0, + help="Resolution in mm for the mesh.", + metavar="RESOLUTION" + ) + parser.add_argument( + "--smoothing_window", + type=int, + default=5, + help="Window size for smoothing the contour.", + metavar="SMOOTHING_WINDOW" + ) + parser.add_argument( + "--colormap", + type=str, + default="red_to_yellow", + choices=["red_to_blue", "blue_to_red", "red_to_yellow", "yellow_to_red"], + help="Colormap to use for thickness visualization, lower to higher values.", + ) + parser.add_argument( + "--color_range", + type=float, + nargs=2, + default=None, + metavar=("MIN", "MAX"), + required=False, + help="Specify the range for the colorbar (2 values: min max). Defaults to automatic choice. \ + (e.g. --color_range 0 10).", + ) + parser.add_argument( + "--legend", + type=str, + default="Thickness (mm)", + help="Legend for the colorbar.", + metavar="LEGEND") + parser.add_argument( + "--twoD", + action="store_true", + help="Generate 2D visualization instead of 3D mesh.", + ) + parser.add_argument( + "-v", + "--verbose", + action="count", + default=0, + help="Enable verbose (pass twice for debug-output).", + ) + return parser + + +def options_parse() -> argparse.Namespace: + """Parse command line arguments for the pipeline.""" + parser = make_parser() + args = parser.parse_args() + + # Create output directory if it doesn't exist + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + return args + + +def load_contours_from_template_dir( + template_dir: Path, resolution: float, smoothing_window: int +) -> list[CCContour]: + """Load all contours and thickness data from a template directory.""" + thickness_files = sorted(template_dir.glob("thickness_values_*.txt")) + if not thickness_files: + raise FileNotFoundError( + f"No thickness files found in template directory {template_dir}. " + "Expected files named thickness_values_.txt and " + "optionally contour_.txt and thickness_measurement_points_.txt." + ) + + fsaverage_contour = None + contours: list[CCContour] = [] + for thickness_file in thickness_files: + try: + idx = int(thickness_file.stem.split("_")[-1]) + except ValueError: + # skip files that do not follow the expected naming + continue + + contour_file = template_dir / f"contour_{idx}.txt" + + if not contour_file.exists(): + # get length of thickness values + thickness_values = np.loadtxt(thickness_file, dtype=str) + # get the non nan thickness values (excluding header), so we know how many points to sample + num_thickness_values = np.sum(~np.isnan(np.array(thickness_values[1:],dtype=float))) + if fsaverage_contour is None: + fsaverage_contour = load_fsaverage_cc_template() + # create measurement points (points = 2 x levelpaths) accorindg to number of thickness values + fsaverage_contour.create_levelpaths(num_points=num_thickness_values // 2, update_data=True) + current_contour = fsaverage_contour.copy() + current_contour.load_thickness_values(thickness_file) + + else: + # this is kinda ugly - maybe we need to overload the constructor to load the contour and thickness values? + current_contour = CCContour(np.empty((0, 2)), np.empty((0,)), resolution=resolution) + current_contour.load_contour(contour_file) + current_contour.load_thickness_values(thickness_file) + + current_contour.fill_thickness_values() + contours.append(current_contour) + + if not contours: + raise ValueError(f"No valid contours could be loaded from {template_dir}") + return contours + + +def main( + template_dir: str | Path, + output_dir: str | Path, + resolution: float = 1.0, + smoothing_window: int = 5, + colormap: str = "red_to_yellow", + color_range: tuple[float, float] | None = None, + legend: str | None = None, + twoD: bool = False, +) -> Literal[0] | str: + """Visualize corpus callosum templates in 2D or 3D.""" + output_dir = Path(output_dir) + color_range = tuple(color_range) if color_range is not None else None + + _, _, vox2ras_tkr = load_fsaverage_data(FSAVERAGE_DATA_PATH) + + contours = load_contours_from_template_dir( + Path(template_dir), resolution=resolution, smoothing_window=smoothing_window, + ) + + # 2D visualization + mid_contour = contours[len(contours) // 2] + + # for now, we only support thickness visualization, this is preparing to plot also p-values and icc values + mode = "thickness" + logger.info(f"Writing output to {output_dir / 'cc_thickness_2d.png'}") + + if mode == "thickness": + raw_thickness_values = mid_contour.thickness_values[~np.isnan(mid_contour.thickness_values)] + # values are duplicated because we they have two measurement points per levelpath + raw_thickness_values = raw_thickness_values[len(raw_thickness_values) // 2:] + mid_contour.plot_contour_colorfill( + plot_values=raw_thickness_values, + title=None, + save_path=str(output_dir / "cc_thickness_2d.png"), + colorbar=True, + mode=mode + ) + if twoD: + return 0 + + # 3D visualization + cc_mesh = create_CC_mesh_from_contours(contours, smooth=0) + + plot_kwargs = dict( + colormap=colormap, + color_range=color_range, + thickness_overlay=True, + legend=legend or "", + ) + cc_mesh.plot_mesh(**plot_kwargs) + cc_mesh.plot_mesh(output_path=str(output_dir / "cc_mesh.html"), **plot_kwargs) + + cc_mesh = cc_mesh.to_fs_coordinates(vox2ras_tkr=vox2ras_tkr) + logger.info(f"Writing vtk file to {output_dir / 'cc_mesh.vtk'}") + cc_mesh.write_vtk(str(output_dir / "cc_mesh.vtk")) + logger.info(f"Writing freesurfer surface file to {output_dir / 'cc_mesh.fssurf'}") + cc_mesh.write_fssurf(str(output_dir / "cc_mesh.fssurf")) + logger.info(f"Writing freesurfer overlay file to {output_dir / 'cc_mesh_overlay.curv'}") + cc_mesh.write_morph_data(str(output_dir / "cc_mesh_overlay.curv")) + try: + cc_mesh.snap_cc_picture(str(output_dir / "cc_mesh_snap.png")) + logger.info(f"Writing 3D snapshot image to {output_dir / 'cc_mesh_snap.png'}") + except RuntimeError: + logger.warning("The cc_visualization script requires whippersnappy>=1.3.1 to makes screenshots, install with " + "`pip install whippersnappy>=1.3.1` !") + return 0 + +if __name__ == "__main__": + options = options_parse() + + # Set up logging if verbose mode is enabled + setup_logging(None, options.verbose) # Log to stdout only + + sys.exit(main( + template_dir=options.template_dir, + output_dir=options.output_dir, + resolution=options.resolution, + smoothing_window=options.smoothing_window, + colormap=options.colormap, + color_range=options.color_range, + legend=options.legend, + twoD=options.twoD, + )) diff --git a/CorpusCallosum/config/checkpoint_paths.yaml b/CorpusCallosum/config/checkpoint_paths.yaml new file mode 100644 index 000000000..ca78b7da2 --- /dev/null +++ b/CorpusCallosum/config/checkpoint_paths.yaml @@ -0,0 +1,7 @@ +url: +- "https://zenodo.org/records/17141933/files" +- "https://b2share.fz-juelich.de/api/files/e4eb699c-ba68-4470-9f3d-89ceeee1a334" + +checkpoint: + segmentation: "checkpoints/FastSurferCC_segmentation_v1.0.0.pkl" + localization: "checkpoints/FastSurferCC_localization_v1.0.0.pkl" diff --git a/CorpusCallosum/data/__init__.py b/CorpusCallosum/data/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/CorpusCallosum/data/constants.py b/CorpusCallosum/data/constants.py new file mode 100644 index 000000000..745809313 --- /dev/null +++ b/CorpusCallosum/data/constants.py @@ -0,0 +1,57 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT + +### Constants +WEIGHTS_PATH = FASTSURFER_ROOT / "checkpoints" +FSAVERAGE_CENTROIDS_PATH = FASTSURFER_ROOT / "CorpusCallosum" / "data" / "fsaverage_centroids.json" +# Contains both affine and header +FSAVERAGE_DATA_PATH = FASTSURFER_ROOT / "CorpusCallosum" / "data" / "fsaverage_data.json" +FSAVERAGE_MIDDLE = 128 # Middle slice index in fsaverage space +CC_LABEL = 192 # Label value for corpus callosum in segmentation +FORNIX_LABEL = 250 # Label value for fornix in segmentation +THIRD_VENTRICLE_LABEL = 4 # Label value for third ventricle in segmentation +SUBSEGMENT_LABELS = [251, 252, 253, 254, 255] # labels for subsegments in segmentation + + +DEFAULT_INPUT_PATHS = { + "conf_name": "mri/orig.mgz", + "aseg_name": "mri/aparc.DKTatlas+aseg.deep.mgz", +} + +DEFAULT_OUTPUT_PATHS = { + ## images + "upright_volume": None, # orig.mgz mapped to upright space + ## segmentations + "segmentation": "mri/callosum.CC.upright.mgz", # corpus callosum segmentation in upright space + "segmentation_in_orig": "mri/callosum.CC.orig.mgz", # cc segmentation in input segmentations space + "softlabels_cc": "mri/callosum.CC.soft.mgz", # cc softlabels in upright space + "softlabels_fn": "mri/fornix.CC.soft.mgz", # fornix softlabels in upright space + "softlabels_background": "mri/background.CC.soft.mgz", # background softlabels in upright space + ## stats + "cc_markers": "stats/callosum.CC.midslice.json", # cc metrics for middle slice + "cc_measures": "stats/callosum.CC.all_slices.json", # cc metrics for all slices + ## transforms + "upright_lta": "mri/transforms/cc_up.lta", # lta transform from orig to upright space + "orient_volume_lta": "mri/transforms/orient_volume.lta", # lta transform from orig to upright+acpc corrected space + ## qc + "qc_image": None, #"callosum.png", # debug image of cc contours + "thickness_image": None, # "callosum.thickness.png", # whippersnappy 3D image of cc thickness + "cc_html": None, # "corpus_callosum.html", # plotly cc visualization + ## surface + "cc_surf": "surf/callosum.surf", # cc surface file + "cc_thickness_overlay": "surf/callosum.thickness.w", # cc surface overlay file + "cc_surf_vtk": "surf/callosum.vtk", # vtk file of cc mesh +} \ No newline at end of file diff --git a/CorpusCallosum/data/fsaverage_cc_template.py b/CorpusCallosum/data/fsaverage_cc_template.py new file mode 100644 index 000000000..55d11bc26 --- /dev/null +++ b/CorpusCallosum/data/fsaverage_cc_template.py @@ -0,0 +1,154 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from pathlib import Path + +import nibabel as nib +import numpy as np +from scipy import ndimage + +from CorpusCallosum.data import constants +from CorpusCallosum.shape.contour import CCContour +from CorpusCallosum.shape.postprocessing import recon_cc_surf_measure +from FastSurferCNN.utils.brainvolstats import mask_in_array + + +def smooth_contour(contour: tuple[np.ndarray, np.ndarray], window_size: int = 5) -> tuple[np.ndarray, np.ndarray]: + """Smooth a contour using a moving average filter. + + Parameters + ---------- + contour : tuple of arrays + The contour coordinates (x, y). + window_size : int + Size of the smoothing window. + + Returns + ------- + tuple of arrays + The smoothed contour coordinates (x, y). + + """ + x, y = contour + + # Ensure the window size is odd + if window_size % 2 == 0: + window_size += 1 + + # Create a padded version of the arrays to handle the edges + x_padded = np.pad(x, (window_size//2, window_size//2), mode='wrap') + y_padded = np.pad(y, (window_size//2, window_size//2), mode='wrap') + + # Apply moving average + x_smoothed = np.zeros_like(x) + y_smoothed = np.zeros_like(y) + + for i in range(len(x)): + x_smoothed[i] = np.mean(x_padded[i:i+window_size]) + y_smoothed[i] = np.mean(y_padded[i:i+window_size]) + + return (x_smoothed, y_smoothed) + + +def load_fsaverage_cc_template() -> tuple[ + np.ndarray, tuple[np.ndarray, np.ndarray], np.ndarray, np.ndarray, np.ndarray, tuple[int, int] +]: + """Load and process the fsaverage corpus callosum template. + + This function loads the fsaverage segmentation from FreeSurfer's data directory, + extracts the corpus callosum mask, and processes it to create a smooth template. + + Returns + ------- + tuple + Contains: + - contour : tuple[np.ndarray, np.ndarray] : x and y coordinates of the contour points. + - anterior_endpoint_idx : np.ndarray : Index of the anterior endpoint. + - posterior_endpoint_idx : np.ndarray : Index of the posterior endpoint. + + Raises + ------ + OSError + If FREESURFER_HOME environment variable is not set correctly. + + """ + # smooth outside contour + # Apply smoothing to the outside contour using a moving average + + try: + freesurfer_home = Path(os.environ['FREESURFER_HOME']) + except KeyError as err: + raise OSError(f"FREESURFER_HOME environment variable is not set correctly or does not exist: " + f"{freesurfer_home}, either provide your own template or set the " + f"FREESURFER_HOME environment variable") from err + + fsaverage_seg_path = freesurfer_home / 'subjects' / 'fsaverage' / 'mri' / 'aparc+aseg.mgz' + fsaverage_seg = nib.load(fsaverage_seg_path) + segmentation = np.asarray(fsaverage_seg.dataobj) + + PC = np.array([131, 99]) + AC = np.array([135, 130]) + + + midslice = segmentation.shape[0]//2 +1 + + cc_mask = mask_in_array(segmentation[midslice], constants.SUBSEGMENT_LABELS) + + # Smooth the CC mask to reduce noise and irregularities + + # Apply binary closing to fill small holes + cc_mask_smoothed = ndimage.binary_closing(cc_mask, structure=np.ones((3, 3))) + + # Apply binary opening to remove small isolated pixels + cc_mask_smoothed = ndimage.binary_opening(cc_mask_smoothed, structure=np.ones((2, 2))) + + # Apply Gaussian smoothing and threshold to get a binary mask again + cc_mask_smoothed = ndimage.gaussian_filter(cc_mask_smoothed.astype(float), sigma=0.8) + cc_mask_smoothed = cc_mask_smoothed > 0.5 + + # Use the smoothed mask for further processing + cc_mask = cc_mask_smoothed.astype(int) * 192 + + _, contour_with_thickness, (anterior_endpoint_idx, posterior_endpoint_idx) = recon_cc_surf_measure( + segmentation=cc_mask[None], + slice_idx=0, + ac_coords=AC, + pc_coords=PC, + affine=fsaverage_seg.affine, + num_thickness_points=100, + subdivisions=[1/6, 1/2, 2/3, 3/4], + subdivision_method="shape", + contour_smoothing=5, + vox_size=(1., 1., 1.), # fsaverage is in 1mm isotropic + ) + outside_contour = contour_with_thickness[:,:2].T + + # make sure the CC stays in shape despite smoothing by moving endpoints outwards + outside_contour[0,anterior_endpoint_idx] -= 55 + outside_contour[0,posterior_endpoint_idx] += 30 + + # Apply smoothing to the outside contour + outside_contour_smoothed = smooth_contour(outside_contour, window_size=11) + outside_contour_smoothed = smooth_contour(outside_contour_smoothed, window_size=15) + outside_contour_smoothed = smooth_contour(outside_contour_smoothed, window_size=30) + outside_contour = outside_contour_smoothed + + fsaverage_contour = CCContour(np.array(outside_contour).T, + np.zeros(len(outside_contour[0])), + endpoint_idxs=(anterior_endpoint_idx, posterior_endpoint_idx), + resolution=1.0) + + + return fsaverage_contour diff --git a/CorpusCallosum/data/fsaverage_centroids.json b/CorpusCallosum/data/fsaverage_centroids.json new file mode 100644 index 000000000..bccf1189d --- /dev/null +++ b/CorpusCallosum/data/fsaverage_centroids.json @@ -0,0 +1,217 @@ +{ + "2": [ + -27.242888317659038, + -22.210776052870685, + 18.546657917012894 + ], + "3": [ + -32.18990180647074, + -16.863336561239265, + 16.015058654310195 + ], + "4": [ + -14.455663189269757, + -13.693461251862885, + 13.7136736214605 + ], + "5": [ + -33.906934306569354, + -22.284671532846716, + -15.821167883211672 + ], + "7": [ + -17.305372931308085, + -53.43157258369229, + -36.01715408448575 + ], + "8": [ + -22.265822784810126, + -64.36629649763144, + -37.674831094198964 + ], + "10": [ + -11.752497096399537, + -19.87584204413473, + 5.165737514518 + ], + "11": [ + -15.034188034188048, + 9.437551695616207, + 6.913427074717404 + ], + "12": [ + -26.366197183098592, + -0.15686274509803866, + -2.091549295774655 + ], + "13": [ + -20.91671388101983, + -5.188668555240795, + -2.4107648725212414 + ], + "14": [ + 0.5832045337454872, + -11.11695002575992, + -3.9433281813498127 + ], + "15": [ + 0.5413500223513665, + -46.56236030397854, + -33.21814930710772 + ], + "16": [ + 0.8273686582297444, + -31.946261594502232, + -31.003755304367417 + ], + "17": [ + -26.088480154888686, + -24.429622458857693, + -15.148886737657307 + ], + "18": [ + -23.90932509015971, + -7.339515713549716, + -20.63575476558475 + ], + "24": [ + 0.6026785714285694, + -20.70535714285714, + 8.040736607142861 + ], + "26": [ + -9.629820051413873, + 10.960154241645256, + -8.786632390745496 + ], + "28": [ + -11.456631660832358, + -16.84694671334111, + -10.32691559704395 + ], + "30": [ + -28.545454545454533, + -3.200000000000003, + -10.181818181818187 + ], + "31": [ + -12.502610966057432, + -12.218015665796344, + 6.30548302872063 + ], + "41": [ + 27.68021284305685, + -21.297671313867227, + 18.84475807220643 + ], + "42": [ + 32.70257488842361, + -15.910019860438453, + 16.482307738602415 + ], + "43": [ + 15.18157827962446, + -13.241715300685101, + 14.257802588175593 + ], + "44": [ + 33.10191082802548, + -17.921443736730367, + -16.980891719745216 + ], + "46": [ + 19.070892410341955, + -53.51368564713019, + -35.67336416710896 + ], + "47": [ + 23.65288732176549, + -64.41682904951904, + -37.19518418854969 + ], + "49": [ + 12.493538246594483, + -19.225986727209218, + 5.663872394923743 + ], + "50": [ + 16.15939771547248, + 9.458463136033231, + 8.239096573208727 + ], + "51": [ + 26.94455762514552, + 0.5477299185099014, + -2.249126891734562 + ], + "52": [ + 22.105321507760536, + -4.939024390243901, + -1.9539911308204125 + ], + "53": [ + 27.74364210135512, + -23.379431965843693, + -14.994987933914985 + ], + "54": [ + 24.942549371633746, + -6.010771992818675, + -20.737881508079 + ], + "58": [ + 9.986789960369876, + 10.424042272126826, + -8.705416116248358 + ], + "60": [ + 12.434200157604408, + -16.41252955082743, + -10.056737588652481 + ], + "62": [ + 30.558139534883722, + -2.581395348837205, + -10.441860465116292 + ], + "63": [ + 12.008567931456554, + -11.022031823745408, + 7.3671970624235 + ], + "77": [ + -13.714285714285722, + -15.714285714285708, + 0.9285714285714306 + ], + "85": [ + 1.466019417475735, + -0.2038834951456323, + -18.466019417475735 + ], + "251": [ + 0.5403535741737073, + -35.800153727901616, + 16.784780937740194 + ], + "252": [ + 0.6063829787234027, + -18.29361702127659, + 24.748936170212772 + ], + "253": [ + 0.5847299813780324, + -2.424581005586589, + 25.815642458100555 + ], + "254": [ + 0.7008849557522154, + 11.998230088495575, + 20.40530973451328 + ], + "255": [ + 0.8761467889908232, + 24.612844036697254, + 5.411009174311928 + ] +} \ No newline at end of file diff --git a/CorpusCallosum/data/fsaverage_data.json b/CorpusCallosum/data/fsaverage_data.json new file mode 100644 index 000000000..0fdd17fbd --- /dev/null +++ b/CorpusCallosum/data/fsaverage_data.json @@ -0,0 +1,88 @@ +{ + "affine": [ + [ + -1.0, + 0.0, + 0.0, + 128.0 + ], + [ + 0.0, + 0.0, + 1.0, + -128.0 + ], + [ + 0.0, + -1.0, + 0.0, + 128.0 + ], + [ + 0.0, + 0.0, + 0.0, + 1.0 + ] + ], + "header": { + "dims": [ + 256, + 256, + 256 + ], + "delta": [ + 1.0, + 1.0, + 1.0 + ], + "Mdc": [ + [ + -1.0, + 0.0, + 0.0 + ], + [ + 0.0, + 0.0, + 10000000000.0 + ], + [ + 0.0, + -10000000000.0, + 0.0 + ] + ], + "Pxyz_c": [ + 128.0, + -128.0, + 128.0 + ] + }, + "vox2ras_tkr": [ + [ + -1.0, + 0.0, + 0.0, + 128.0 + ], + [ + 0.0, + 0.0, + 1.0, + -128.0 + ], + [ + 0.0, + -1.0, + 0.0, + 128.0 + ], + [ + 0.0, + 0.0, + 0.0, + 1.0 + ] + ] +} \ No newline at end of file diff --git a/CorpusCallosum/data/generate_fsaverage_centroids.py b/CorpusCallosum/data/generate_fsaverage_centroids.py new file mode 100644 index 000000000..b1ef7b19a --- /dev/null +++ b/CorpusCallosum/data/generate_fsaverage_centroids.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Script to generate static fsaverage centroids file. + +This script extracts centroids from the fsaverage template segmentation +and saves them to a JSON file for fast loading during pipeline execution. +Run this script once to generate the centroids file. +""" + +import json +import os +from pathlib import Path + +import nibabel as nib +import numpy as np +from read_write import calc_ras_centroids_from_seg, convert_numpy_to_json_serializable + +import FastSurferCNN.utils.logging as logging + +logger = logging.get_logger(__name__) + + +def main() -> None: + """Generate and save fsaverage centroids to a static file. + + This script extracts centroids from the fsaverage template segmentation + and saves them to a JSON file for fast loading during pipeline execution. + + The script performs the following steps: + 1. Load fsaverage segmentation from FreeSurfer directory + 2. Extract centroids for all anatomical structures + 3. Save centroids to JSON file + 4. Extract and save affine matrix and header fields + + Raises + ------ + OSError + If FREESURFER_HOME environment variable is not set or invalid + FileNotFoundError + If required fsaverage files are not found + + Notes + ----- + The script saves two files: + - fsaverage_centroids.json : Contains centroids for each anatomical structure + - fsaverage_data.json : Contains affine matrix and header information + """ + + # Get fsaverage path from FreeSurfer environment + try: + fs_home = Path(os.environ['FREESURFER_HOME']) + if not fs_home.exists(): + raise OSError(f"FREESURFER_HOME environment variable is not set correctly or does not exist: {fs_home}") + + fsaverage_path = fs_home / 'subjects' / 'fsaverage' + if not fsaverage_path.exists(): + raise OSError(f"fsaverage path does not exist: {fsaverage_path}") + + fsaverage_aseg_path = fsaverage_path / 'mri' / 'aseg.mgz' + if not fsaverage_aseg_path.exists(): + raise FileNotFoundError(f"fsaverage aseg file does not exist: {fsaverage_aseg_path}") + + except KeyError as err: + raise OSError("FREESURFER_HOME environment variable is not set") from err + + logger.info(f"Loading fsaverage segmentation from: {fsaverage_aseg_path}") + + # Load fsaverage segmentation + fsaverage_nib = nib.load(fsaverage_aseg_path) + + # Extract centroids + logger.info("Extracting centroids from fsaverage...") + centroids_dst = calc_ras_centroids_from_seg(fsaverage_nib) + + logger.info(f"Found {len(centroids_dst)} anatomical structures with centroids") + + # Convert to JSON-serializable format + centroids_serializable = convert_numpy_to_json_serializable(centroids_dst) + + # Save centroids to JSON file + centroids_output_path = Path(__file__).parent / "fsaverage_centroids.json" + logger.info(f"Saving fsaverage centroids to {centroids_output_path}") + with open(centroids_output_path, 'w') as f: + json.dump(centroids_serializable, f, indent=2) + + logger.info(f"Fsaverage centroids saved to: {centroids_output_path}") + logger.info(f"Centroids file size: {centroids_output_path.stat().st_size} bytes") + + # Extract and save fsaverage affine matrix and header fields + logger.info("Extracting fsaverage affine matrix and header fields...") + fsaverage_affine = fsaverage_nib.affine.astype(float) # Convert to float for JSON serialization + + # Extract header fields needed for LTA + header = fsaverage_nib.header + dims = [int(x) for x in header.get_data_shape()[:3]] # Convert to int for JSON serialization + delta = [float(x) for x in header.get_zooms()[:3]] # Convert to float for JSON serialization + vox2ras = header.get_vox2ras() + + # Direction cosines matrix (Mdc) - extract rotation part without scaling + delta_diag = np.diag(delta) + # Avoid division by zero by using a small epsilon for zero values + delta_safe = np.where(delta_diag == 0, 1e-10, delta_diag) + Mdc = (vox2ras[:3, :3] / delta_safe).astype(float) # Convert to float for JSON serialization + + Pxyz_c = vox2ras[:3, 3].astype(float) # Convert to float for JSON serialization + + # Combine affine and header data + combined_data = { + "affine": fsaverage_affine.tolist(), # Convert numpy array to list for JSON serialization + "vox2ras_tkr": fsaverage_nib.header.get_vox2ras_tkr().tolist(), + "header": { + "dims": dims, + "delta": delta, + "Mdc": Mdc.tolist(), # Convert numpy array to list for JSON serialization + "Pxyz_c": Pxyz_c.tolist() # Convert numpy array to list for JSON serialization + } + } + + # Convert the entire structure to JSON-serializable format to handle any remaining numpy types + combined_data_serializable = convert_numpy_to_json_serializable(combined_data) + + # Save combined data to JSON file + combined_output_path = Path(__file__).parent / "fsaverage_data.json" + logger.info(f"Saving fsaverage affine and header data to {combined_output_path}") + with open(combined_output_path, 'w') as f: + json.dump(combined_data_serializable, f, indent=2) + + logger.info(f"Fsaverage affine and header data saved to: {combined_output_path}") + logger.info(f"Combined file size: {combined_output_path.stat().st_size} bytes") + logger.info(f"Affine matrix shape: {fsaverage_affine.shape}") + logger.info(f"Header dims: {dims}, delta: {delta}") + + # Print some statistics + label_ids = list(centroids_dst.keys()) + logger.info(f"Label IDs range: {min(label_ids)} to {max(label_ids)}") + logger.info("Sample centroids:") + for label_id in sorted(label_ids)[:5]: + centroid = centroids_dst[label_id] + logger.info(f" Label {label_id}: [{centroid[0]:.2f}, {centroid[1]:.2f}, {centroid[2]:.2f}]") + + logger.info("Fsaverage affine matrix:") + logger.info(fsaverage_affine) + + logger.info("Fsaverage header fields:") + logger.info(f" dims: {dims}") + logger.info(f" delta: {delta}") + logger.info(f" Mdc shape: {Mdc.shape}") + logger.info(f" Pxyz_c: {Pxyz_c}") + logger.info("Combined data structure created successfully") + + +if __name__ == "__main__": + main() diff --git a/CorpusCallosum/data/read_write.py b/CorpusCallosum/data/read_write.py new file mode 100644 index 000000000..e11b86324 --- /dev/null +++ b/CorpusCallosum/data/read_write.py @@ -0,0 +1,223 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from pathlib import Path +from typing import TypedDict + +import numpy as np +from numpy import typing as npt + +import FastSurferCNN.utils.logging as logging +from FastSurferCNN.utils import AffineMatrix4x4, nibabelImage +from FastSurferCNN.utils.parallel import thread_executor + + +class FSAverageHeader(TypedDict): + dims: npt.NDArray[int] + delta: npt.NDArray[float] + Mdc: npt.NDArray[float] + Pxyz_c: npt.NDArray[float] + +logger = logging.get_logger(__name__) + + +def calc_ras_centroids_from_seg(seg_img: nibabelImage, label_ids: list[int] | None = None) \ + -> dict[int, np.ndarray | None]: + """Get centroids of segmentation labels in RAS coordinates, accepts any affine/data layout. + + Parameters + ---------- + seg_img : nibabel.analyze.SpatialImage + Input segmentation image. + label_ids : list[int], optional + List of label IDs to extract centroids for. If None, extracts all non-zero labels. + + Returns + ------- + dict[int, np.ndarray | None] + A dict mapping label IDs to their centroids (x,y,z) in RAS coordinates, None if label did not exist. + """ + # Get segmentation data and affine + seg_data: npt.NDArray[np.integer] = np.asarray(seg_img.dataobj) + vox2ras: AffineMatrix4x4 = seg_img.affine + + # Get unique labels + if label_ids is None: + labels = np.unique(seg_data) + labels = labels[labels > 0] # Exclude background + else: + labels = label_ids + + def _each_label(label): + # Get voxel indices for this label + if np.any(mask := seg_data == label): + # Calculate centroid in voxel space + vox_centroid = np.mean(np.where(mask), axis=1, dtype=float) + + # Convert to homogeneous coordinates + vox_centroid_hom = np.append(vox_centroid, 1) + + # Transform to RAS coordinates and return without homogeneous coordinate + return int(label), (vox2ras @ vox_centroid_hom)[:3] + else: + return int(label), None + + return dict(thread_executor().map(_each_label, labels)) + + +def convert_numpy_to_json_serializable(obj: object) -> object: + """Convert numpy types to JSON serializable types. + + Parameters + ---------- + obj : dict, list, array, number, serializable + Object to convert to JSON serializable type. + + Returns + ------- + object + JSON serializable version of the input object. + """ + if isinstance(obj, dict): + return {k: convert_numpy_to_json_serializable(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [convert_numpy_to_json_serializable(item) for item in obj] + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, (np.integer, np.floating)): + # Handle numpy scalar types + return obj.item() + else: + return obj + + +def load_fsaverage_centroids(centroids_path: str | Path) -> dict[int, npt.NDArray[float]]: + """Load fsaverage centroids from static JSON file. + + Parameters + ---------- + centroids_path : str or Path + Path to the JSON file containing centroids. + + Returns + ------- + dict[int, np.ndarray] + Dictionary mapping label IDs to their centroids in RAS coordinates. + """ + + centroids_path = Path(centroids_path) + if not centroids_path.exists(): + raise FileNotFoundError(f"Fsaverage centroids file not found: {centroids_path}") + + with open(centroids_path) as f: + centroids_data = json.load(f) + + # Convert string keys back to integers and lists back to numpy arrays + return {int(label): np.array(centroid) for label, centroid in centroids_data.items()} + + +def load_fsaverage_affine(affine_path: str | Path) -> npt.NDArray[float]: + """Load fsaverage affine matrix from static text file. + + Parameters + ---------- + affine_path : str or Path + Path to the text file containing affine matrix. + + Returns + ------- + np.ndarray + 4x4 affine transformation matrix. + """ + + affine_path = Path(affine_path) + if not affine_path.exists(): + raise FileNotFoundError(f"Fsaverage affine file not found: {affine_path}") + + affine_matrix = np.loadtxt(affine_path).astype(float) + + if affine_matrix.shape != (4, 4): + raise ValueError(f"Expected 4x4 affine matrix, got shape {affine_matrix.shape}") + + return affine_matrix + + +def load_fsaverage_data(data_path: str | Path) -> tuple[AffineMatrix4x4, FSAverageHeader, AffineMatrix4x4]: + """Load fsaverage affine matrix and header fields from static JSON file. + + Parameters + ---------- + data_path : str or Path + Path to the JSON file containing combined data. + + Returns + ------- + affine_matrix : AffineMatrix4x4 + 4x4 affine transformation matrix. + header_fields : dict + Header fields needed for LTA: + - dims : list[int] + Volume dimensions [x,y,z]. + - delta : list[float] + Voxel size in mm [x,y,z]. + - Mdc : np.ndarray + 3x3 direction cosines matrix. + - Pxyz_c : np.ndarray + RAS center coordinates [x,y,z]. + vox2ras_tkr : AffineMatrix4x4 + Voxel to RAS tkr-space transformation matrix. + + Raises + ------ + FileNotFoundError + If the data file doesn't exist. + json.JSONDecodeError + If the file is not valid JSON. + ValueError + If required fields are missing. + """ + data_path = Path(data_path) + if not data_path.exists(): + raise FileNotFoundError(f"Fsaverage data file not found: {data_path}") + + with open(data_path) as f: + data = json.load(f) + + # Verify required fields + if "affine" not in data: + raise ValueError("Required field 'affine' missing from data file") + if "header" not in data: + raise ValueError("Required field 'header' missing from data file") + + required_header_fields = ["dims", "delta", "Mdc", "Pxyz_c"] + for field in required_header_fields: + if field not in data["header"]: + raise ValueError(f"Required header field missing: {field}") + + # Convert lists back to numpy arrays + affine_matrix = np.array(data["affine"]) + vox2ras_tkr = np.array(data["vox2ras_tkr"]) + header_data = FSAverageHeader( + dims=data["header"]["dims"], + delta=data["header"]["delta"], + Mdc=np.array(data["header"]["Mdc"]), + Pxyz_c=np.array(data["header"]["Pxyz_c"]), + ) + + # Validate affine matrix shape + if affine_matrix.shape != (4, 4): + raise ValueError(f"Expected 4x4 affine matrix, got shape {affine_matrix.shape}") + + return affine_matrix, header_data, vox2ras_tkr diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py new file mode 100644 index 000000000..a0fbb4fa4 --- /dev/null +++ b/CorpusCallosum/fastsurfer_cc.py @@ -0,0 +1,1032 @@ +#!/usr/bin/env python3 +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +from collections.abc import Iterable +from functools import partial +from pathlib import Path +from time import perf_counter_ns +from typing import Literal, TypeVar, cast + +import nibabel as nib +import numpy as np +import torch +from monai.networks.nets import DenseNet +from scipy.ndimage import affine_transform + +from CorpusCallosum.data.constants import ( + CC_LABEL, + DEFAULT_INPUT_PATHS, + DEFAULT_OUTPUT_PATHS, + FSAVERAGE_CENTROIDS_PATH, + FSAVERAGE_DATA_PATH, + FSAVERAGE_MIDDLE, + THIRD_VENTRICLE_LABEL, +) +from CorpusCallosum.data.read_write import ( + FSAverageHeader, + calc_ras_centroids_from_seg, + convert_numpy_to_json_serializable, + load_fsaverage_centroids, + load_fsaverage_data, +) +from CorpusCallosum.localization import inference as localization_inference +from CorpusCallosum.segmentation import inference as segmentation_inference +from CorpusCallosum.segmentation import segmentation_postprocessing +from CorpusCallosum.shape.postprocessing import ( + check_area_changes, + create_sag_slice_vox2vox, + make_subdivision_mask, + recon_cc_surf_measures_multi, +) +from CorpusCallosum.utils.mapping_helpers import ( + apply_transform_to_pt, + apply_transform_to_volume, + calc_mapping_to_standard_space, + map_softlabels_to_orig, +) +from CorpusCallosum.utils.types import CCMeasuresDict, SliceSelection, SubdivisionMethod +from FastSurferCNN.data_loader.conform import conform, is_conform +from FastSurferCNN.segstats import HelpFormatter +from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Mask3d, Shape3d, Vector2d, logging, nibabelImage, Image4d +from FastSurferCNN.utils.arg_types import path_or_none +from FastSurferCNN.utils.common import SubjectDirectory, find_device +from FastSurferCNN.utils.lta import write_lta +from FastSurferCNN.utils.parallel import shutdown_executors, thread_executor, get_num_threads, serial_executor +from FastSurferCNN.utils.parser_defaults import modify_argument +from recon_surf.align_points import find_rigid + +logger = logging.get_logger(__name__) + +_TPathLike = TypeVar("_TPathLike", str, Path, Literal[None]) + + +class ArgumentDefaultsHelpFormatter(HelpFormatter): + """Help message formatter which adds default values to argument help.""" + + def _get_help_string(self, action): + """ + Add the default value to the option help message. + """ + help = action.help + if help is None: + help = '' + + if "%(default)" not in help and not getattr(action, "required", False): + if action.default is not argparse.SUPPRESS and not getattr(action.default, "DO_NOT_PRINT_DEFAULT", False): + defaulting_nargs = [argparse.OPTIONAL, argparse.ZERO_OR_MORE] + if action.option_strings or action.nargs in defaulting_nargs: + help += " (not used by default)" if action.default is None else " (default: %(default)s)" + return help + + +class _FixFloatFormattingList(list): + def __init__(self, items: Iterable, item_format_spec: str): + self._format_spec = item_format_spec + super().__init__(items) + + def __str__(self): + return "[" + ", ".join(map(lambda x: format(x, self._format_spec), self)) + "]" + + +def _do_not_print(value): + class _DoNotPrintGeneric(type(value)): + DO_NOT_PRINT_DEFAULT = True + + return _DoNotPrintGeneric(value) + + +def make_parser() -> argparse.ArgumentParser: + """Create the argument parse object for the pipeline.""" + from FastSurferCNN.utils.parser_defaults import add_arguments + + parser = argparse.ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + + parser.add_argument( + "-v", + "--verbose", + action="count", + default=_do_not_print(0), + help="Enable verbose (pass twice for debug-output).", + ) + # Specify subject directory + subject ID, OR specify individual MRI and segmentation files + output paths + add_arguments(parser, ["sd", "sid", "conformed_name", "aseg_name", "device"]) + + def _set_help_sid(action): + action.help = "The subject id to use." + modify_argument(parser, "--sid", _set_help_sid) + + parser.add_argument( + "--num_thickness_points", + type=int, + default=100, + help="Number of points for thickness estimation." + ) + parser.add_argument( + "--subdivisions", + type=float, + metavar="FRAC", + default=_FixFloatFormattingList([1 / 6, 1 / 2, 2 / 3, 3 / 4], ".3f"), + help="List of subdivision fractions for the corpus callosum subsegmentation." + "The method allows for an arbitrary number of fractions." + "By default it uses following Hofer-Frahms convention." + ) + parser.add_argument( + "--subdivision_method", + default=_do_not_print("shape"), + help="Method for contour subdivision. Options:
" + "- shape (default): Intercallosal subdivision perpendicular to intercallosal line,
" + "- vertical: orthogonal to the most anterior and posterior points in the AC/PC standardized CC contour, " + "
" + "- angular: subdivision based on equally spaced angles, as proposed by Hampel and colleagues,
" + "- eigenvector: primary direction, same as FreeSurfers mri_cc.", + choices=["shape", "vertical", "angular", "eigenvector"], + ) + parser.add_argument( + "--contour_smoothing", + type=float, + default=5, + help="Gaussian sigma for smoothing during contour detection. Higher values mean a smoother CC outline, at the " + "cost of precision.", + ) + def _slice_selection(a: str) -> SliceSelection: + if (b := a.lower()) in ("middle", "all"): + return b + return int(a) + parser.add_argument( + "--slice_selection", + type=_slice_selection, + default=_do_not_print("all"), + help="Which slices to process. Options: 'middle', 'all' (default), or a specific slice number.", + ) + + ######## OUTPUT PATHS ######### + # 4. Options for advanced, technical parameters + advanced = parser.add_argument_group( + title="Advanced options", + description="Custom output paths, useful if no standard case directory is used. Relative paths are always " + "relative to the subject_dir defined via --sd and --sid!", + ) + add_arguments(advanced, ["threads"]) + advanced.add_argument( + "--segmentation", "--seg", + type=path_or_none, + help="Output path for corpus callosum and fornix segmentation output.", + default=Path(DEFAULT_OUTPUT_PATHS["segmentation"]), + ) + advanced.add_argument( + "--segmentation_in_orig", + type=path_or_none, + help="Output path for corpus callosum and fornix segmentation output in the input MRI space.", + default=DEFAULT_OUTPUT_PATHS["segmentation_in_orig"], + ) + advanced.add_argument( + "--cc_measures", + type=path_or_none, + help="Output path for surface-based corpus callosum measures describing shape and volume for each image slice.", + default=Path(DEFAULT_OUTPUT_PATHS["cc_measures"]), + ) + advanced.add_argument( + "--cc_mid_measures", + type=path_or_none, + help="Output path for surface-based corpus callosum measures of the midslice describing CC shape and volume.", + default=DEFAULT_OUTPUT_PATHS["cc_markers"], + ) + advanced.add_argument( + "--upright_lta", + type=path_or_none, + help="Output path for upright LTA transform. This makes sure the midplane is at 128 in LR direction, " + "but no nodding correction is applied.", + default=DEFAULT_OUTPUT_PATHS["upright_lta"], + ) + advanced.add_argument( + "--upright_volume", + type=path_or_none, + help="Output path for upright volume (input image with cc_up.lta applied).", + default=None, + ) + advanced.add_argument( + "--orient_volume_lta", + type=path_or_none, + help="Output path for orientation volume LTA transform. This makes sure the midplane is the volume center, " + "the anterior and posterior commisures are on the coordinate line, and the posterior commissure is " + "at the origin - standardizing the head position.", + default=DEFAULT_OUTPUT_PATHS["orient_volume_lta"], + ) + advanced.add_argument( + "--qc_image", + type=path_or_none, + help="Output path for QC visualization image.", + default=DEFAULT_OUTPUT_PATHS["qc_image"], + ) + advanced.add_argument( + "--save_template_dir", + type=path_or_none, + help="Directory path where to save contours.txt and thickness_values.txt files. These files can be used to " + "visualize the CC shape and volume with the cc_visualization.py script.", + default=None, + ) + advanced.add_argument( + "--thickness_image", + type=path_or_none, + help="Output path for thickness image.", + default=DEFAULT_OUTPUT_PATHS["thickness_image"], + ) + advanced.add_argument( + "--surf", + dest="cc_surf", + type=path_or_none, + help="Output path for surf file.", + default=DEFAULT_OUTPUT_PATHS["cc_surf"], + ) + advanced.add_argument( + "--thickness_overlay", + type=path_or_none, + help="Output path for corpus callosum thickness overlay file.", + default=DEFAULT_OUTPUT_PATHS["cc_thickness_overlay"], + ) + advanced.add_argument( + "--cc_interactive_html", "--cc_html", + dest="cc_html", + type=path_or_none, + help="Output path to the corpus callosum interactive 3D visualization HTML file.", + default=DEFAULT_OUTPUT_PATHS["cc_html"], + ) + advanced.add_argument( + "--cc_surf_vtk", + type=path_or_none, + help=f"Output path for vtk file, showing the CC 3D mesh. Example: {DEFAULT_OUTPUT_PATHS['cc_surf_vtk']}.", + default=None, + ) + advanced.add_argument( + "--softlabels_cc", + type=path_or_none, + help=f"Output path for corpus callosum softlabels, which contains the soft labels of each voxel. " + f"Example: {DEFAULT_OUTPUT_PATHS['softlabels_cc']}.", + default=None, + ) + advanced.add_argument( + "--softlabels_fn", + type=path_or_none, + help=f"Output path for fornix softlabels, which contains the soft labels of each voxel. " + f"Example: {DEFAULT_OUTPUT_PATHS['softlabels_fn']}.", + default=None, + ) + advanced.add_argument( + "--softlabels_background", + type=path_or_none, + help=f"Output path for background softlabels, which contains the probability of each voxel. " + f"Example: {DEFAULT_OUTPUT_PATHS['softlabels_background']}.", + default=None, + ) + ############ END OF OUTPUT PATHS ############ + return parser + + +def options_parse() -> argparse.Namespace: + """Parse command line arguments for the pipeline.""" + parser = make_parser() + args = parser.parse_args() + + # Reconstruct subject_dir from sd and sid (but sd might be stored as out_dir by parser_defaults) + sd_value = getattr(args, 'out_dir', None) + if sd_value and hasattr(args, 'sid') and args.sid: + args.subject_dir = Path(sd_value) / args.sid + else: + args.subject_dir = None + + # Validation logic: must use either directory approach (--sd + --sid) OR file approach (--conf_name + --aseg_name) + if sd_value: + # Using directory approach - make sure sid was also provided + if not (hasattr(args, 'sid') and args.sid): + parser.error("When using --sd, you must also provide --sid.") + elif hasattr(args, 'sid') and args.sid: + # If sid is provided without sd, that's an error + if not sd_value: + parser.error("When using --sid, you must also provide --sd.") + elif hasattr(args, 'conf_name') and args.conf_name: + # Using file approach - make sure aseg_name was also provided + if not (hasattr(args, 'aseg_name') and args.aseg_name): + parser.error("When using --conf_name, you must also provide --aseg_name.") + elif hasattr(args, 'aseg_name') and args.aseg_name: + # If aseg_name is provided without conf_name, that's an error + if not (hasattr(args, 'conf_name') and args.conf_name): + parser.error("When using --aseg_name, you must also provide --conf_name.") + else: + parser.error("You must specify either --sd and --sid OR both --conf_name and --aseg_name.") + + # If subject_dir is provided, set default paths for missing arguments + if args.subject_dir: + # Create standard FreeSurfer subdirectories + if not args.conf_name: + args.conf_name = args.subject_dir / DEFAULT_INPUT_PATHS["conf_name"] + + if not args.aseg_name: + args.aseg_name = args.subject_dir / DEFAULT_INPUT_PATHS["aseg_name"] + else: + print("WARNING: Not providing subject_dir leads to discarding of files with relative paths!") + args.subject_dir = None + for arg, path in (("--aseg_name", args.aseg_name), ("--conformed_name", args.conf_name)): + if path is None or not Path(path).is_absolute(): + parser.error( + f"When not passing --sd , arguments of --aseg_name and --conformed_name must be " + f"absolute! But the argument passed to {arg} was {path}, i.e. not absolute." + ) + + all_paths = ("segmentation", "segmentation_in_orig", "cc_measures", "upright_lta", "orient_volume_lta", + "cc_surf", "softlabels_cc", "softlabels_fn", "softlabels_background", "cc_mid_measures", + "thickness_overlay", "qc_image", "thickness_image", "cc_html") + + warnings_paths = [] + # Create parent directories for all output paths + for path_name in all_paths: + path: Path | None = getattr(args, path_name, None) + if isinstance(path, Path) and not args.subject_dir and not path.is_absolute(): + # set path to none in arguments + warnings_paths.append(path_name) + setattr(args, path_name, None) + if warnings_paths: + _warnings_paths = "' '".join(warnings_paths) + print(f"WARNING: Not writing '{_warnings_paths}', because --sd and --sid are not specified and " + f"its paths are relative.") + return args + + +def register_centroids_to_fsavg(aseg_nib: nibabelImage) \ + -> tuple[AffineMatrix4x4, AffineMatrix4x4, AffineMatrix4x4, FSAverageHeader, AffineMatrix4x4]: + """Perform centroid-based registration between subject and fsaverage space. + + Computes a rigid transformation between the subject's segmentation and fsaverage space + by aligning centroids of corresponding anatomical structures. + + Parameters + ---------- + aseg_nib : nibabel.analyze.SpatialImage + Subject's segmentation image. + + Returns + ------- + aseg2fsaverage_vox2vox : AffineMatrix4x4 + Transformation matrix from original to fsaverage voxel space. + aseg2fsaverage_ras2ras : AffineMatrix4x4 + Transformation matrix from original to fsaverage RAS space. + fsaverage_hires_vox2ras : AffineMatrix4x4 + High-resolution fsaverage affine matrix. + fsaverage_header : FSAverageHeader + FSAverage header fields for LTA writing. + fsaverage_vox2ras_tkr : AffineMatrix4x4 + Voxel to RAS tkr-space transformation matrix. + + Notes + ----- + The function uses pre-computed fsaverage centroids and data from static files + to perform the registration. It matches corresponding anatomical structures + between the subject's segmentation and fsaverage space. + """ + logger.info("Starting centroid registration") + + # Load pre-computed fsaverage centroids and data from static files + fsaverage_data_future = thread_executor().submit(load_fsaverage_data, FSAVERAGE_DATA_PATH) + ras_centroids_dst = load_fsaverage_centroids(FSAVERAGE_CENTROIDS_PATH) + + ras_centroids_mov = calc_ras_centroids_from_seg(aseg_nib, label_ids=list(ras_centroids_dst.keys())) + + # get the set of joint labels + joint_centroid_labels = [lbl for lbl, v in ras_centroids_mov.items() if v is not None] + + ras_centroids_mov = np.array([ras_centroids_mov[lbl] for lbl in joint_centroid_labels]).T + ras_centroids_dst = np.array([ras_centroids_dst[lbl] for lbl in joint_centroid_labels]).T + + aseg2fsaverage_ras2ras: AffineMatrix4x4 = find_rigid(p_mov=ras_centroids_mov.T, p_dst=ras_centroids_dst.T) + + # make affine that increases resolution to orig resolution + aseg_zooms = list(nib.as_closest_canonical(aseg_nib).header.get_zooms()[:3]) + resolution_trans: AffineMatrix4x4 = np.diagflat([aseg_zooms[0], aseg_zooms[2], aseg_zooms[1], 1]).astype(float) + + fsaverage_vox2ras, fsavg_header, vox2ras_tkr = fsaverage_data_future.result() + fsavg_header["delta"] = np.asarray([aseg_zooms[0], aseg_zooms[2], aseg_zooms[1]]) # vox sizes in lia + # fsavg_hires_vox2ras translation should be 128 always (independent of resolution) + fsavg_hires_vox2ras: AffineMatrix4x4 = np.concatenate( + [(resolution_trans @ fsaverage_vox2ras)[:, :3], fsaverage_vox2ras[:, 3:4]], + axis=1, + ) + fsavg_header["dims"] = np.ceil(fsavg_header["dims"] @ np.linalg.inv(resolution_trans[:3, :3])).astype(int).tolist() + + aseg2fsavg_vox2vox: AffineMatrix4x4 = np.linalg.inv(fsavg_hires_vox2ras) @ aseg2fsaverage_ras2ras @ aseg_nib.affine + logger.info("Centroid registration successful!") + return aseg2fsavg_vox2vox, aseg2fsaverage_ras2ras, fsavg_hires_vox2ras, fsavg_header, vox2ras_tkr + + +def localize_ac_pc( + orig_data: Image3d, + aseg_nib: nibabelImage, + orig2midslice_vox2vox: AffineMatrix4x4, + model_localization: DenseNet, + resample_shape: Shape3d, +) -> tuple[Vector2d, Vector2d]: + """Localize anterior and posterior commissure points in the brain. + + Uses a trained model to detect AC and PC points in mid-sagittal slices, + using the third ventricle as an anatomical reference. + + Parameters + ---------- + orig_data : np.ndarray + Array of intensity data. + aseg_nib : nibabelImage + Subject's segmentation image in native subject space. + orig2midslice_vox2vox : np.ndarray + Transformation matrix from subject/native space to fsaverage space (in lia). + model_localization : DenseNet + Trained model for AC-PC detection. + resample_shape : 3-tuple of ints + Number of slices to process. + + Returns + ------- + ac_coords : np.ndarray + Coordinates of the anterior commissure. + pc_coords : np.ndarray + Coordinates of the posterior commissure. + """ + num_slices_to_analyze = resample_shape[0] + resample_shape = (num_slices_to_analyze + 2,) + resample_shape[1:] # 2 for context slices + _midslices_fut = thread_executor().submit( + affine_transform, + orig_data, + np.linalg.inv(orig2midslice_vox2vox), # inverse is required for affine_transform + output_shape=resample_shape, + order=2, # unclear, why this is not order=3 + mode="constant", + cval=0, + prefilter=True, # unclear, why we are using a smoothing filter here + ) + + # get center of third ventricle from aseg and map to fsaverage space (voxel coordinates) + third_ventricle_mask = np.asarray(aseg_nib.dataobj) == THIRD_VENTRICLE_LABEL + third_ventricle_center = np.argwhere(third_ventricle_mask).mean(axis=0) + third_ventricle_center_vox = apply_transform_to_pt(third_ventricle_center, orig2midslice_vox2vox, inv=False) + + # get 5 mm of slices with 3 slices per inference (cropping num_slices_to_analyze + 2 slices around the center) + ac_coords, pc_coords = localization_inference.run_inference_on_slice( + model_localization, _midslices_fut.result(), third_ventricle_center_vox[1:], + ) + + return ac_coords, pc_coords + + +def segment_cc( + midslices: Image3d, + ac_coords: Vector2d, + pc_coords: Vector2d, + aseg_nib: nibabelImage, + model_segmentation: "torch.nn.Module", +) -> tuple[Mask3d, Image4d]: + """Segment the corpus callosum using a trained model. + + Performs corpus callosum segmentation on mid-sagittal slices using a trained model, with AC-PC points as anatomical + references. Includes post-processing to clean the cc_seg_labels. + + Parameters + ---------- + midslices : np.ndarray + Array of mid-sagittal slices. + ac_coords : np.ndarray + Anterior commissure coordinates. + pc_coords : np.ndarray + Posterior commissure coordinates. + aseg_nib : nibabelImage + Subject's cc_seg_labels image. + model_segmentation : torch.nn.Module + Trained model for CC cc_seg_labels. + + Returns + ------- + cc_seg_labels : np.ndarray + Binary cc_seg_labels of the corpus callosum. + cc_softlabels : np.ndarray + Soft cc_seg_labels probabilities of shape (H, W, D, C=3). + """ + pre_clean_segmentation, inputs, cc_softlabels = segmentation_inference.run_inference_on_slice( + model_segmentation, + midslices, + ac_center=ac_coords, + pc_center=pc_coords, + voxel_size=nib.as_closest_canonical(aseg_nib).header.get_zooms()[2:0:-1], # convert from RAS to LIA + ) + + cc_seg_labels, cc_volume_mask = segmentation_postprocessing.clean_cc_segmentation(pre_clean_segmentation) + + # print a warning if the cc_volume_mask touches the edge of the segmentation + if np.any(cc_volume_mask[:, [0, -1]]) or np.any(cc_volume_mask[:, :, [0, -1]]): + logger.warning("CC volume mask touches the edge of the cc_seg_labels field-of-view, CC might be truncated") + + # get voxels that were removed during cleaning + cleaned_mask = pre_clean_segmentation != cc_seg_labels + cc_softlabels[cleaned_mask, 1] = 0 + cc_softlabels[cleaned_mask, :] /= np.sum(cc_softlabels[cleaned_mask, :], axis=-1, keepdims=True) + 1e-6 + + return cc_seg_labels, cc_softlabels + + +def main( + conf_name: str | Path, + aseg_name: str | Path, + subject_dir: str | Path, + slice_selection: SliceSelection = "middle", + num_thickness_points: int = 100, + subdivisions: list[float] | None = None, + subdivision_method: SubdivisionMethod = "shape", + contour_smoothing: float = 5, + save_template_dir: str | Path | None = None, + device: str | torch.device = "auto", + upright_volume: str | Path | None = None, + segmentation: str | Path | None = None, + cc_measures: str | Path | None = None, + cc_mid_measures: str | Path | None = None, + upright_lta: str | Path | None = None, + orient_volume_lta: str | Path | None = None, + cc_surf: str | Path | None = None, + cc_thickness_overlay: str | Path | None = None, + cc_html: str | Path | None = None, + cc_surf_vtk: str | Path | None = None, + segmentation_in_orig: str | Path | None = None, + qc_image: str | Path | None = None, + thickness_image: str | Path | None = None, + softlabels_cc: str | Path | None = None, + softlabels_fn: str | Path | None = None, + softlabels_background: str | Path | None = None, +) -> None: + """Main pipeline function for corpus callosum analysis. + + This function performs the complete corpus callosum analysis pipeline including + registration, landmark detection, segmentation, and morphometry analysis. + + Parameters + ---------- + conf_name : str or Path + Path to input MRI file. + aseg_name : str or Path + Path to input segmentation file. + subject_dir : str or Path + FastSurfer/FreeSurfer subject directory and directory for output files. + slice_selection : "middle", "all" or int, default="middle" + Which slices to process. + num_thickness_points : int, default=100 + Number of points for thickness estimation. + subdivisions : list[float], optional + List of subdivision fractions for CC subsegmentation. + subdivision_method : any of "shape", "vertical", "angular", "eigenvector", default="shape" + Method for contour subdivision. + contour_smoothing : float, default=5 + Gaussian sigma for smoothing during contour detection. + save_template_dir : str or Path, optional + Directory path where to save contours.txt and thickness_values.txt files. These files can be used to visualize + the CC shape and volume in 3D. Files are only saved, if a valid directory path is passed. + device : str, default="auto" + Device to run inference on ('auto', 'cpu', 'cuda', or 'cuda:X'). + upright_volume : str or Path, optional + Path to save upright volume. + segmentation : str or Path, optional + Path to save segmentation. + cc_measures : str or Path, optional + Path to save post-processing results. + cc_mid_measures : str or Path, optional + Path to save CC markers. + upright_lta : str or Path, optional + Path to save upright LTA transform. + orient_volume_lta : str or Path, optional + Path to save orientation transform. + cc_surf : str or Path, optional + Path to save surface file. + cc_thickness_overlay : str or Path, optional + Path to save overlay file. + cc_html : str or Path, optional + Path to save HTML visualization. + cc_surf_vtk : str or Path, optional + Path to save VTK file. + segmentation_in_orig : str or Path, optional + Path to save segmentation in original space. + qc_image : str or Path, optional + Path to save QC images. + thickness_image : str or Path, optional + Path to save thickness visualization. + softlabels_cc : str or Path, optional + Path to save CC soft labels. + softlabels_fn : str or Path, optional + Path to save fornix soft labels. + softlabels_background : str or Path, optional + Path to save background soft labels. + + Notes + ----- + The function saves multiple outputs to specified paths or default locations in output_dir: + - cc_markers.json: Contains detected landmarks and measurements. + - midplane_slices.mgz: Extracted midplane slices. + - upright_volume.mgz: Volume aligned to standard orientation. + - segmentation.mgz: Corpus callosum segmentation. + - cc_postproc_results.json: Enhanced postprocessing results. + - Various visualization plots and transformation matrices. + + The pipeline consists of the following steps: + 1. Initializes environment and loads models. + 2. Registers input image to fsaverage space. + 3. Detects AC and PC points. + 4. Segments the corpus callosum. + 5. Performs enhanced post-processing analysis. + 6. Saves results and visualizations. + """ + start = perf_counter_ns() + + import sys + + if subdivisions is None: + subdivisions = [1 / 6, 1 / 2, 2 / 3, 3 / 4] + + subject_dir = Path("/dev/null/no-subject-dir" if subject_dir is None else subject_dir) + + logger.info("Starting corpus callosum analysis pipeline") + logger.info(f"Input MRI: {conf_name}") + logger.info(f"Input segmentation: {aseg_name}") + logger.info(f"Output directory: {subject_dir}") + + # Convert all paths to Path objects + sd = SubjectDirectory( + subject_dir.parent, + id=subject_dir.name, + conf_name=conf_name, + aseg_name=aseg_name, + save_template_dir=save_template_dir, + upright_volume=upright_volume, + cc_segmentation=segmentation, + cc_measures=cc_measures, + cc_mid_measures=cc_mid_measures, + upright_lta=upright_lta, + cc_orient_volume_lta=orient_volume_lta, + cc_surf=cc_surf, + cc_thickness_overlay=cc_thickness_overlay, + cc_html=cc_html, + cc_mesh=cc_surf_vtk, + cc_orig_segfile=segmentation_in_orig, + cc_qc_image=qc_image, + cc_thickness_image=thickness_image, + cc_softlabels_cc=softlabels_cc, + cc_softlabels_fn=softlabels_fn, + cc_softlabels_background=softlabels_background, + ) + + # Validate subdivision fractions + if any(i < 0 or i > 1 for i in subdivisions): + logger.error(f"Subdivision fractions must be between 0 and 1, but got: {subdivisions}") + sys.exit(1) + + #### setup variables + io_futures = [] + + # load models + device = find_device(device) + logger.info(f"Using device: {device}") + + logger.info("Loading models") + _model_localization = thread_executor().submit(localization_inference.load_model, device=device) + _model_segmentation = thread_executor().submit(segmentation_inference.load_model, device=device) + + _aseg_fut = thread_executor().submit(nib.load, sd.filename_by_attribute("aseg_name")) + orig = cast(nibabelImage, nib.load(sd.conf_name)) + + # check that the image is conformed, i.e. isotropic 1mm voxels, 256^3 size, LIA orientation + if not is_conform(orig, vox_size=None, img_size=None, orientation=None): + logger.info("Internally conforming orig to soft-LIA.") + orig = conform(orig, vox_size=None, img_size=None, orientation=None) + + # 5 mm around the midplane (guaranteed to be aligned RAS by as_closest_canonical) + vox_size_ras: tuple[float, float, float] = nib.as_closest_canonical(orig).header.get_zooms() + vox_size = vox_size_ras[0], vox_size_ras[2], vox_size_ras[1] # convert from RAS to LIA + slices_to_analyze = int(np.ceil(5 / vox_size[0])) + # slices_to_analyze must be odd + if slices_to_analyze % 2 == 0: + slices_to_analyze += 1 + + logger.info( + f"Segmenting {slices_to_analyze} slices (5 mm width at {vox_size[0]} mm resolution, " + "center around the mid-sagittal plane)" + ) + + aseg_img = cast(nibabelImage, _aseg_fut.result()) + + if not np.allclose(aseg_img.affine, orig.affine): + logger.error("Input MRI and segmentation are not aligned! Please check your input files.") + sys.exit(1) + + logger.info("Performing centroid registration to fsaverage space") + orig2fsavg_vox2vox, orig2fsavg_ras2ras, fsavg_vox2ras, fsavg_header, _ = ( + register_centroids_to_fsavg(aseg_img) + ) + + # start saving upright volume, this is the image in fsaverage space but not yet oriented via AC-PC + if sd.has_attribute("upright_volume"): + # upright == fsaverage-aligned + io_futures.append( + thread_executor().submit( + apply_transform_to_volume, + orig, + orig2fsavg_vox2vox, + fsavg_vox2ras, + output_path=sd.filename_by_attribute("upright_volume"), + output_size=fsavg_header["dims"], + ) + ) + + # calculate affine for segmentation volume + affine_x_offset = partial(create_sag_slice_vox2vox, fsaverage_middle=FSAVERAGE_MIDDLE / vox_size[0]) + fsavg2midslab_in_vox2vox: AffineMatrix4x4 = affine_x_offset(slices_to_analyze // 2) + # first, midslice->fsaverage in vox2vox, then vox2ras in fsaverage space + fsaverage_midslab_vox2ras: AffineMatrix4x4 = fsavg_vox2ras @ np.linalg.inv(fsavg2midslab_in_vox2vox) + + # calculate vox2vox for input resampling volumes + def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: + fsavg2midslab = affine_x_offset(slices_to_analyze // 2 + additional_context // 2) + # first, orig->fsaverage in vox2vox, then fsaverage->midslab in vox2vox + return fsavg2midslab @ orig2fsavg_vox2vox + + #### do localization and segmentation inference + logger.info("Starting AC/PC localization") + target_shape: tuple[int, int, int] = (slices_to_analyze, fsavg_header["dims"][1], fsavg_header["dims"][2]) + # predict ac and pc coordinates in upright AS space + ac_coords, pc_coords = localize_ac_pc( + np.asarray(orig.dataobj), + aseg_img, + _orig2midslab_vox2vox(additional_context=2), + _model_localization.result(), + target_shape, + ) + logger.info("Starting corpus callosum segmentation") + extra_slices = 8 # 8 extra in x-direction for context slices + target_shape: Shape3d = (slices_to_analyze + extra_slices, fsavg_header["dims"][1], fsavg_header["dims"][2]) + midslices: Image3d = affine_transform( + np.asarray(orig.dataobj), + np.linalg.inv(_orig2midslab_vox2vox(additional_context=extra_slices)), # inverse is required for affine_transform + output_shape=target_shape, + order=2, # @ClePol unclear, why this is not order=3 + mode="constant", + cval=0, + prefilter=True, # unclear, why we are using a smoothing filter here + ) + cc_fn_seg_labels, cc_fn_softlabels = segment_cc( + midslices, + ac_coords, + pc_coords, + aseg_img, + _model_segmentation.result(), + ) + + # save segmentation softlabels + for i, (attr, name) in enumerate((("background",) * 2, ("cc", "Corpus Callosum"), ("fn", "Fornix"))): + if sd.has_attribute(f"cc_softlabels_{attr}"): + logger.info(f"Saving {name} softlabels to {sd.filename_by_attribute(f'cc_softlabels_{attr}')}") + io_futures.append(thread_executor().submit( + nib.save, + nib.MGHImage(cc_fn_softlabels[..., i], fsaverage_midslab_vox2ras, orig.header), + sd.filename_by_attribute(f"cc_softlabels_{attr}"), + )) + + # Create a temporary segmentation image with proper affine for enhanced postprocessing + # Process slices based on selection mode + + logger.info(f"Processing slices with selection mode: {slice_selection}") + slice_results, slice_io_futures = recon_cc_surf_measures_multi( + segmentation=cc_fn_seg_labels, + slice_selection=slice_selection, + fsavg_vox2ras=fsavg_vox2ras, + midslices=midslices, + ac_coords=ac_coords, + pc_coords=pc_coords, + num_thickness_points=num_thickness_points, + subdivisions=subdivisions, + subdivision_method=subdivision_method, + contour_smoothing=contour_smoothing, + vox_size=vox_size, + subject_dir=sd, + ) + io_futures.extend(slice_io_futures) + + outer_contours = [slice_result["split_contours"][0] for slice_result in slice_results] + + if len(outer_contours) > 1 and not check_area_changes(outer_contours): + logger.warning( + "Large area changes detected between consecutive slices, this is likely due to a segmentation error." + ) + + # Get middle slice result + middle_slice_result: CCMeasuresDict = slice_results[len(slice_results) // 2] + if len(middle_slice_result["split_contours"]) <= 5: + cc_subseg_midslice = make_subdivision_mask( + (cc_fn_seg_labels.shape[1], cc_fn_seg_labels.shape[2]), + middle_slice_result["split_contours"], + vox_size[1:3], + ) + else: + logger.warning("Too many subsegments for lookup table, skipping sub-division of output segmentation.") + cc_subseg_midslice = None + + # save segmentation labels, this + if sd.has_attribute("cc_segmentation"): + io_futures.append(thread_executor().submit( + nib.save, + nib.MGHImage(cc_fn_seg_labels, fsaverage_midslab_vox2ras, orig.header), + sd.filename_by_attribute("cc_segmentation"), + )) + # map soft labels to original space (in parallel because this takes a while, and we only do it to save the labels) + if sd.has_attribute("cc_orig_segfile"): + # if num_threads is not large enough (>1), this might be blocking ; serial_executor runs the function in submit + executor = thread_executor() if get_num_threads() > 2 else serial_executor() + io_futures.append(executor.submit( + map_softlabels_to_orig, + cc_fn_softlabels=cc_fn_softlabels, + orig=orig, + orig_space_segmentation_path=sd.filename_by_attribute("cc_orig_segfile"), + orig2slab_vox2vox=_orig2midslab_vox2vox(), + cc_subseg_midslice=cc_subseg_midslice, + orig2midslice_vox2vox=affine_x_offset(0) @ orig2fsavg_vox2vox, # orig2fsavg, then full2midslice + )) + + METRICS = [ + "areas", + "thickness", + "curvature", + "midline_length", + "circularity", + "cc_index", + "total_area", + "total_perimeter", + "thickness_profile", + ] + + # Record key metrics for middle slice + output_metrics_middle_slice = {metric: middle_slice_result[metric] for metric in METRICS} + + # Create enhanced output dictionary with all slice results + per_slice_output_dict = { + "slices": [convert_numpy_to_json_serializable({metric: result[metric] for metric in METRICS}) + for result in slice_results], + } + + ########## Save outputs ########## + additional_metrics = {} + if len(outer_contours) > 1: + cc_volume_voxel = segmentation_postprocessing.get_cc_volume_voxel( + desired_width_mm=5, + cc_mask=np.equal(cc_fn_seg_labels, CC_LABEL), + voxel_size=vox_size, # in LIA order + ) + logger.info(f"CC volume voxel: {cc_volume_voxel}") + # FIXME: Create a proper mesh and use cc_mesh.volume for this volume + try: + cc_volume_contour = segmentation_postprocessing.get_cc_volume_contour( + cc_contours=outer_contours, + voxel_size=vox_size, # in LIA order + ) + logger.info(f"CC volume contour: {cc_volume_contour}") + except AssertionError as e: + logger.warning("Could not compute CC volume from contours, setting to NaN") + logger.exception(e) + cc_volume_contour = float('nan') + + additional_metrics["cc_5mm_volume"] = cc_volume_voxel + additional_metrics["cc_5mm_volume_pv_corrected"] = cc_volume_contour + + # get ac and pc in all spaces + ac_coords_3d = np.hstack((FSAVERAGE_MIDDLE, ac_coords)) + pc_coords_3d = np.hstack((FSAVERAGE_MIDDLE, pc_coords)) + standardized2orig_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig = ( + calc_mapping_to_standard_space(orig, ac_coords_3d, pc_coords_3d, orig2fsavg_vox2vox) + ) + + # write output dict as csv + additional_metrics["ac_center"] = ac_coords_orig + additional_metrics["pc_center"] = pc_coords_orig + additional_metrics["ac_center_oriented_volume"] = ac_coords_standardized + additional_metrics["pc_center_oriented_volume"] = pc_coords_standardized + additional_metrics["ac_center_upright"] = ac_coords_3d + additional_metrics["pc_center_upright"] = pc_coords_3d + additional_metrics["slices_in_segmentation"] = slices_to_analyze + additional_metrics["voxel_size"] = np.asarray(orig.header.get_zooms(), dtype=float).tolist() + additional_metrics["num_thickness_points"] = num_thickness_points + additional_metrics["subdivision_method"] = subdivision_method + additional_metrics["subdivision_ratios"] = subdivisions + additional_metrics["contour_smoothing"] = contour_smoothing + additional_metrics["slice_selection"] = slice_selection + + + if sd.has_attribute("cc_mid_measures"): + io_futures.append(thread_executor().submit( + save_cc_measures_json, + sd.filename_by_attribute('cc_mid_measures'), + output_metrics_middle_slice | additional_metrics, + )) + + if sd.has_attribute("cc_measures"): + io_futures.append(thread_executor().submit( + save_cc_measures_json, + sd.filename_by_attribute("cc_measures"), + per_slice_output_dict | additional_metrics, + )) + + # save lta to fsaverage space + + if sd.has_attribute("upright_lta"): + sd.filename_by_attribute("cc_mid_measures").parent.mkdir(exist_ok=True, parents=True) + logger.info(f"Saving LTA to fsaverage space: {sd.filename_by_attribute('upright_lta')}") + io_futures.append(thread_executor().submit( + write_lta, + sd.filename_by_attribute("upright_lta"), + orig2fsavg_ras2ras, + sd.filename_by_attribute("aseg_name"), + aseg_img.header, + "fsaverage", + fsavg_header, + )) + + if sd.has_attribute("cc_orient_volume_lta"): + sd.filename_by_attribute("cc_orient_volume_lta").parent.mkdir(exist_ok=True, parents=True) + # save lta to standardized space (fsaverage + nodding + ac to center) + orig2standardized_ras2ras = orig.affine @ np.linalg.inv(standardized2orig_vox2vox) @ np.linalg.inv(orig.affine) + logger.info(f"Saving LTA to standardized space: {sd.filename_by_attribute('cc_orient_volume_lta')}") + io_futures.append(thread_executor().submit( + write_lta, + sd.filename_by_attribute("cc_orient_volume_lta"), + orig2standardized_ras2ras, + sd.conf_name, + orig.header, + sd.conf_name, + orig.header, + )) + + # this waits for all io to finish + for fut in io_futures: + e = fut.exception() + if e and isinstance(e, Exception): + logger.exception(e) + shutdown_executors() + + duration = (perf_counter_ns() - start) / 1e9 + logger.info(f"CorpusCallosum analysis pipeline completed successfully in {duration:.2f} seconds.") + + +def save_cc_measures_json(cc_mid_measure_file: Path, metrics: dict[str, object]): + """Save JSON metrics file.""" + # Convert numpy arrays to lists for JSON serialization + logger.info(f"Saving CC markers to {cc_mid_measure_file}") + cc_mid_measure_file.parent.mkdir(exist_ok=True, parents=True) + with open(cc_mid_measure_file, "w") as f: + json.dump(convert_numpy_to_json_serializable(metrics), f, indent=4) + + +if __name__ == "__main__": + options = options_parse() + + # Set up logging if verbose mode is enabled + logging.setup_logging(None, options.verbose) # Log to stdout only + + main( + conf_name=options.conf_name, + aseg_name=options.aseg_name, + subject_dir=options.subject_dir, + #FIXME: slice_selection is True/bool + slice_selection=options.slice_selection, + num_thickness_points=options.num_thickness_points, + subdivisions=list(options.subdivisions), + subdivision_method=str(options.subdivision_method), + contour_smoothing=options.contour_smoothing, + save_template_dir=options.save_template_dir, + device=options.device, + upright_volume=options.upright_volume, + segmentation=options.segmentation, + cc_measures=options.cc_measures, + cc_mid_measures=options.cc_mid_measures, + upright_lta=options.upright_lta, + orient_volume_lta=options.orient_volume_lta, + cc_surf=options.cc_surf, + cc_thickness_overlay=options.thickness_overlay, + cc_html=options.cc_html, + cc_surf_vtk=options.cc_surf_vtk, + segmentation_in_orig=options.segmentation_in_orig, + qc_image=options.qc_image, + thickness_image=options.thickness_image, + softlabels_cc=options.softlabels_cc, + softlabels_fn=options.softlabels_fn, + softlabels_background=options.softlabels_background, + ) diff --git a/CorpusCallosum/localization/__init__.py b/CorpusCallosum/localization/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/CorpusCallosum/localization/inference.py b/CorpusCallosum/localization/inference.py new file mode 100644 index 000000000..257c963a0 --- /dev/null +++ b/CorpusCallosum/localization/inference.py @@ -0,0 +1,252 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Literal + +import numpy as np +import torch +from monai import transforms +from monai.networks.nets import DenseNet + +from CorpusCallosum.transforms.localization import CropAroundACPCFixedSize +from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML +from CorpusCallosum.utils.types import Points2dType +from FastSurferCNN.download_checkpoints import load_checkpoint_config_defaults +from FastSurferCNN.download_checkpoints import main as download_checkpoints +from FastSurferCNN.utils import Image3d, Vector2d, Vector3d +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT + +PATCH_SIZE = (64, 64) + + +def load_model(device: torch.device) -> DenseNet: + """Load trained numerical localization model from checkpoint. + + Parameters + ---------- + device : torch.device + Device to load model to. + + Returns + ------- + DenseNet + Loaded and initialized model in evaluation mode. + """ + + # Initialize model architecture (must match training) + model = DenseNet( # densenet201 + spatial_dims=2, + in_channels=3, + out_channels=4, + init_features=64, + growth_rate=32, + block_config=(6, 12, 48, 32), + bn_size=4, + act=("relu", {"inplace": True}), + norm=("batch", {"affine": True}), + dropout_prob=0.2 + ) + + download_checkpoints(cc=True) + cc_config = load_checkpoint_config_defaults( + "checkpoint", + filename=CC_YAML, + ) + checkpoint_path = FASTSURFER_ROOT / cc_config['localization'] + + # Load state dict + if isinstance(checkpoint_path, str) or isinstance(checkpoint_path, Path): + state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True) + if isinstance(state_dict, dict) and 'model_state_dict' in state_dict: + state_dict = state_dict['model_state_dict'] + else: + state_dict = checkpoint_path + + model.load_state_dict(state_dict) + model = model.to(device) + model.eval() + return model + + +def get_transforms() -> transforms.Compose: + """Get preprocessing transforms for inference. + + Returns + ------- + transforms.Compose + Composed transform pipeline including: + - Intensity scaling to [0,1] + - Fixed size cropping around AC-PC points + """ + tr = [ + transforms.ScaleIntensityd(keys=['image'], minv=0, maxv=1), + CropAroundACPCFixedSize(keys=['image'], fixed_size=PATCH_SIZE, random_translate=0), + ] + return transforms.Compose(tr) + + +def preprocess_volume( + image_volume: np.ndarray, + center_pt: Vector3d, + transform: transforms.Transform | None = None +) -> dict[str, torch.Tensor | tuple[int, ...]]: + """Preprocess a volume for inference. + + Parameters + ---------- + image_volume : np.ndarray + Input image volume of shape (W, W, D) in RAS. + center_pt : np.ndarray + Center point coordinates for cropping on the slice with shape (3,). + transform : transforms.Transform or None, optional + Custom transform pipeline, by default None. + If None, uses default transforms from get_transforms(). + + Returns + ------- + dict[str, torch.Tensor | tuple[int, ...]] + Dictionary containing preprocessed image tensor. + """ + if transform is None: + transform = get_transforms() + + # During training we used AC/PC coordinates, but during inference we approximate this by the center of the third + # ventricle. Therefore we put in the third ventricle center as dummy AC/PC coordinates for cropping the image. + sample = {"image": image_volume[None], "AC_center": center_pt[1:][None], "PC_center": center_pt[1:][None]} + + # Apply transforms + transformed = transform(sample) + + # Add batch dimension if needed + if torch.is_tensor(transformed["image"]): + if transformed["image"].ndim == 3: + transformed["image"] = transformed["image"].unsqueeze(0) + + return transformed + +def predict( + model: torch.nn.Module, + image_volume: Image3d, + patch_center: np.ndarray, + device: torch.device | None = None, + transform: transforms.Transform | None = None + ) -> tuple[Points2dType, Points2dType, tuple[int, int]]: + """ + Run inference on an image volume + + Parameters + ---------- + model : DenseNet + Trained model for inference. + image_volume : np.ndarray + Input volume as numpy array. + patch_center : np.ndarray + Initial center point estimate for cropping. + device : torch.device, optional + Device to run inference on, by default None. + transform : transforms.Transform, optional + Custom transform pipeline, defaults to preconfigured transforms of `get_transforms`. + + Returns + ------- + pc_ccord : np.ndarray + Predicted PC coordinates. + ac_coord : np.ndarray + Predicted AC coordinates. + crop_offsets : pair of ints + Crop offsets (left, top). + """ + if device is None: + device = next(model.parameters()).device + + # prepend zero to third_ventricle_center + patch_center_3d = np.concatenate([np.zeros(1), patch_center]) + + # Preprocess + t_dict = preprocess_volume(image_volume, patch_center_3d, transform) + + transformed_original = t_dict['image'] + inputs = transformed_original.to(device) + + inputs = inputs.transpose(0, 1) + inputs = inputs.unfold(0, 3, 1).transpose(1, -1)[..., 0] + + # Run inference + with torch.no_grad(): + outputs = model(inputs) * torch.as_tensor([PATCH_SIZE + PATCH_SIZE], device=device) + + t_crops = [(t_dict['crop_left'] + t_dict['crop_top']) * 2] + outs: np.ndarray[tuple[int, Literal[4]], np.dtype[np.float_]] = outputs.cpu().numpy() + np.asarray(t_crops, dtype=float) + crop_offsets: tuple[int, int] = (t_dict["crop_left"][0], t_dict["crop_top"][0]) + return outs[:, :2], outs[:, 2:], crop_offsets + + +def run_inference_on_slice( + model: DenseNet, + image_slab: Image3d, + center_pt: Vector2d, + num_iterations: int = 2, + debug_output: str | None = None, +) -> tuple[Vector2d, Vector2d]: + """Run inference on a single slice to detect AC and PC points. + + Parameters + ---------- + model : torch.nn.Module + Trained model for AC-PC detection. + image_slab : np.ndarray + 3D image mid-slices to run inference on in RAS. + center_pt : np.ndarray + Initial center point estimate for cropping. + num_iterations : int, default=2 + Number of refinement iterations to run. + debug_output : str, optional + Path to save debug visualization. + + Returns + ------- + ac_coords : np.ndarray + Detected AC voxel coordinates with shape (2,) containing its [y,x] positions. + pc_coords : np.ndarray + Detected PC voxel coordinates with shape (2,) containing its [y,x] positions. + """ + + if num_iterations < 1: + raise ValueError("localization inference with less than 1 iteration is invalid!") + + pc_coords, ac_coords = center_pt[None], center_pt[None] + crop_left, crop_top = 0, 0 + # Run inference + for _ in range(num_iterations): + pc_coords, ac_coords, (crop_left, crop_top) = predict(model, image_slab, center_pt) + center_pt = np.mean(np.stack([ac_coords, pc_coords], axis=0), axis=(0, 1)) + # average ac and pc coords across sagittal slices + _pc_coords = np.mean(pc_coords, axis=0) + _ac_coords = np.mean(ac_coords, axis=0) + + if debug_output is not None: + import matplotlib.pyplot as plt + from matplotlib.patches import Rectangle + fig, ax = plt.subplots(1, 1, figsize=(10, 8)) + ax.imshow(image_slab[image_slab.shape[0] // 2, :, :], cmap='gray') + # Plot points on all views + ax.scatter(pc_coords[:, 1], pc_coords[:, 0], c='r', marker='x', label='PC') + ax.scatter(ac_coords[:, 1], ac_coords[:, 0], c='b', marker='x', label='AC') + # make a box where the crop is + ax.add_patch(Rectangle((crop_top, crop_left), 64, 64, fill=False, color='r', linewidth=2)) + plt.savefig(debug_output, bbox_inches='tight') + plt.close() + + return _ac_coords, _pc_coords diff --git a/CorpusCallosum/paint_cc_into_pred.py b/CorpusCallosum/paint_cc_into_pred.py new file mode 100644 index 000000000..420abaeab --- /dev/null +++ b/CorpusCallosum/paint_cc_into_pred.py @@ -0,0 +1,348 @@ +# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# IMPORTS + +import argparse +import sys +from functools import partial +from pathlib import Path +from typing import TypeVar, cast + +import nibabel as nib +import numpy as np +from numpy import typing as npt +from scipy import ndimage + +import FastSurferCNN.utils.logging as logging +from CorpusCallosum.data.constants import FORNIX_LABEL, SUBSEGMENT_LABELS +from FastSurferCNN.data_loader.conform import is_conform +from FastSurferCNN.reduce_to_aseg import reduce_to_aseg_and_save +from FastSurferCNN.utils.arg_types import path_or_none +from FastSurferCNN.utils.brainvolstats import mask_in_array +from FastSurferCNN.utils.parallel import thread_executor + +_T = TypeVar("_T", bound=np.number) + +logger = logging.get_logger(__name__) + +HELPTEXT = """ +Script to add corpus callosum segmentation (CC, FreeSurfer IDs 251-255) to +deep-learning prediction (e.g. aparc.DKTatlas+aseg.deep.mgz). + + +USAGE: +paint_cc_into_pred -in_cc -in_pred -out + + +Dependencies: + Python 3.8+ + + Nibabel to read and write FreeSurfer data + http://nipy.org/nibabel/ + +Original Author: Leonie Henschel +Date: Jul-10-2020 + +""" + + +def argument_parse(): + """Create a command line interface and return command line options. + """ + parser = make_parser() + + args = parser.parse_args() + + if args.input_cc is None or args.input_pred is None or args.output is None: + sys.exit("ERROR: Please specify input and output segmentations") + + return args + + +def make_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(usage=HELPTEXT) + parser.add_argument( + "--input_cc", + "-in_cc", + dest="input_cc", + type=Path, + required=True, + help="path to input segmentation with Corpus Callosum (IDs 251-255 in FreeSurfer space)", + ) + parser.add_argument( + "--input_pred", + "-in_pred", + dest="input_pred", + type=Path, + required=True, + help="path to input segmentation Corpus Callosum should be added to.", + ) + parser.add_argument( + "--output", + "-out", + dest="output", + type=Path, + required=True, + help="path to output (input segmentation + added CC)", + ) + parser.add_argument( + "--reduce_to_aseg", + "-aseg", + dest="aseg", + type=path_or_none, + required=False, + help="optionally also reduce the resulting segmentation to aseg and save separately.", + default=None, + ) + return parser + + +def paint_in_cc(pred: npt.NDArray[np.int_], + aseg_cc: npt.NDArray[np.int_]) -> npt.NDArray[np.int_]: + """Paint corpus callosum segmentation into aseg+dkt segmentation map. + + Parameters + ---------- + pred : npt.NDArray[np.int_] + Deep-learning segmentation map. + aseg_cc : npt.NDArray[np.int_] + Aseg segmentation with CC. + + Returns + ------- + npt.NDArray[np.int_] + Segmentation map with added CC. + + Notes + ----- + This function modifies the original array and does not create a copy. + The CC labels (251-255) from aseg_cc are copied into pred. + """ + cc_mask = mask_in_array(aseg_cc, SUBSEGMENT_LABELS) + pred[cc_mask] = aseg_cc[cc_mask] + return pred + +def correct_wm_ventricles( + aseg_cc: npt.NDArray[np.int_], + fornix_mask: npt.NDArray[np.bool_], + voxel_size: tuple[float, float, float], + close_gap_size_mm: float = 3.0 +) -> npt.NDArray[np.int_]: + """Correct WM mask and ventricle labels according to the CC and fornix masks. + + The function + Take non-CC-connected WM components -> remove + Take FN -> WM + Fill space in superior inferior direction between CC and left/right Ventricle with corresponding Ventricle labels + """ + + # Create a copy to avoid modifying the original + corrected_pred = aseg_cc.copy() + + # Get CC mask (labels 251-255) + cc_mask = mask_in_array(aseg_cc, SUBSEGMENT_LABELS) + + # Get left and right ventricle masks + all_ventricle_mask = (aseg_cc == 4) | (aseg_cc == 43) + + # Combine all WM labels + all_wm_mask = (aseg_cc == 2) | (aseg_cc == 41) + + # 1. Fill space between CC and ventricles + # Only fill small gaps (up to 3 voxels) between CC and ventricle boundaries + #for ventricle_label, ventricle_mask in [(4, left_ventricle_mask), (43, right_ventricle_mask)]: + + # Process each slice independently + for x in range(corrected_pred.shape[0]): + cc_slice = cc_mask[x] + #vent_slice = ventricle_mask + all_wm_slice = all_wm_mask[x] + + if all_wm_slice.any() and cc_slice.any(): + + # Dilate CC mask to find adjacent voxels, then check for overlap with component + cc_dilated = ndimage.binary_dilation(cc_slice, iterations=1) + # Label connected components in WM + labeled_wm, num_components = ndimage.label(all_wm_slice) + + # Find components that are adjacent to CC and remove them + for label in range(1, num_components + 1): + component_mask = labeled_wm == label + # Check if this component is adjacent to (touches) the CC + if np.any(component_mask & cc_dilated): + corrected_pred[x][component_mask] = 0 # Set to background + + if fornix_mask[x].any(): + fornix_slice = fornix_mask[x] + # count WM labels overlapping with fornix + left_wm_overlap = np.sum(fornix_slice & (aseg_cc == 2)) + right_wm_overlap = np.sum(fornix_slice & (aseg_cc == 41)) + corrected_pred[x][fornix_slice] = 2 + (left_wm_overlap > right_wm_overlap) * 39 # Left WM / Right WM + + vent_slice = all_ventricle_mask + potential_fill = np.asarray([False]) + if cc_slice.any() and vent_slice.any(): + # Create binary masks for this slice + cc_binary = cc_slice.astype(bool) + vent_binary = vent_slice.astype(bool) + + # Dilate both masks slightly to find potential connection points + max_gap_vox = int(np.ceil(voxel_size[1] * close_gap_size_mm)) + cc_dilated = ndimage.binary_dilation(cc_binary, iterations=max_gap_vox) + vent_dilated = ndimage.binary_dilation(vent_binary, iterations=max_gap_vox) + + # Find voxels that are adjacent to both CC and ventricle but not part of either + potential_fill = (cc_dilated & vent_dilated) & ~(cc_binary | vent_binary) + + # Only fill small gaps between CC and ventricle in inferior-superior direction + if not potential_fill.any(): + for z in range(potential_fill.shape[1]): + potential_fill_line = potential_fill[:, z] + labeled_gaps, num_gaps = ndimage.label(potential_fill_line) + cc_line = cc_binary[:, z] + vent_line = vent_binary[:, z] + + for gap_label in range(1, num_gaps + 1): + gap_mask = labeled_gaps == gap_label + + # check that CC and ventricle are connected to the gap_mask + dilated_gap_mask = ndimage.binary_dilation(gap_mask, iterations=1) + if not np.any(cc_line & dilated_gap_mask): + continue + if not np.any(vent_line & dilated_gap_mask): + continue + + vent_label_location = np.where(vent_line & dilated_gap_mask)[0] + vent_label = corrected_pred[x, vent_label_location, z] + + if np.sum(gap_mask) > max_gap_vox: + continue + + corrected_pred[x, :, z][gap_mask & (corrected_pred[x, :, z] == 0)] = vent_label + + # Process gaps in z-direction (within each y-row) + for y in range(potential_fill.shape[0]): + potential_fill_line = potential_fill[y, :] + labeled_gaps, num_gaps = ndimage.label(potential_fill_line) + cc_line = cc_binary[y, :] + vent_line = vent_binary[y, :] + + for gap_label in range(1, num_gaps + 1): + gap_mask = labeled_gaps == gap_label + + # check that CC and ventricle are connected to the gap_mask + dilated_gap_mask = ndimage.binary_dilation(gap_mask, iterations=1) + if not np.any(cc_line & dilated_gap_mask): + continue + if not np.any(vent_line & dilated_gap_mask): + continue + + vent_label_location = np.where(vent_line & dilated_gap_mask)[0] + if len(vent_label_location) > 0: + vent_label = corrected_pred[x, y, vent_label_location[0]] # Take first match + + if np.sum(gap_mask) > max_gap_vox: + continue + + corrected_pred[x, y, :][gap_mask & (corrected_pred[x, y, :] == 0)] = vent_label + + return corrected_pred + + +if __name__ == "__main__": + from FastSurferCNN.utils import nibabelImage + + # Command Line options are error checking done here + options = argument_parse() + + logging.setup_logging() + + logger.info(f"Reading inputs: {options.input_cc} {options.input_pred}...") + cc_seg_image = cast(nibabelImage, nib.load(options.input_cc)) + cc_seg_data = np.asanyarray(cc_seg_image.dataobj) + aseg_image = cast(nibabelImage, nib.load(options.input_pred)) + aseg_data = np.asanyarray(aseg_image.dataobj) + + def _is_conform(img, dtype, verbose): + return is_conform(img, vox_size=None, img_size=None, verbose=verbose, dtype=dtype) + + conform_args = (cc_seg_image, aseg_image), (np.uint8, np.integer) + conform_checks = list(thread_executor().map(partial(_is_conform, verbose=False), *conform_args)) + + if not all(conform_checks): + names = [] + dtypes = [] + for conform_check, img, dtype, name in zip(conform_checks, *conform_args, ("CC", "Prediction"), strict=True): + if not conform_check: + _is_conform(img, dtype, verbose=True) + names.append(name) + dtypes.append(dtype.name if hasattr(dtype, "name") else str(dtype)) + sys.exit( + f"Error: {' and '.join(names)} input image is not conformed (LIA orientation, {'/'.join(dtypes)} dtype). " + "Please conform the image(s) using the conform.py script." + ) + if not np.allclose(cc_seg_image.affine, aseg_image.affine): + sys.exit("Error: The affine matrices of the aseg and the corpus callosum images are not the same.") + + # Paint CC into prediction + pred_with_cc = paint_in_cc(aseg_data, cc_seg_data) + + # Apply WM and ventricle corrections + logger.info("Applying white matter and ventricle corrections...") + fornix_mask = cc_seg_data == FORNIX_LABEL + voxel_size = tuple(aseg_image.header.get_zooms()) + pred_corrected = correct_wm_ventricles(aseg_data, fornix_mask, voxel_size) + + logger.info(f"Writing segmentation with corpus callosum to: {options.output}") + pred_with_cc_fin = nib.MGHImage(pred_corrected, aseg_image.affine, aseg_image.header) + io_fut = thread_executor().submit(pred_with_cc_fin.to_filename, options.output) + + if options.aseg is not None: + rta_fut = thread_executor().submit( + reduce_to_aseg_and_save, + pred_corrected, + aseg_image.affine, + aseg_image.header, + options.aseg, + ) + else: + rta_fut = None + + # Count initial labels + initial_cc = np.sum(mask_in_array(aseg_data, SUBSEGMENT_LABELS)) + initial_fornix = np.sum(aseg_data == FORNIX_LABEL) + initial_wm = np.sum((aseg_data == 2) | (aseg_data == 41)) + logger.info(f"Initial segmentation: CC={initial_cc}, Fornix={initial_fornix}, WM={initial_wm}") + + after_paint_cc = np.sum(mask_in_array(pred_with_cc, SUBSEGMENT_LABELS)) + logger.info(f"After painting CC: {after_paint_cc} CC voxels added") + + # Count final labels + final_cc = np.sum(mask_in_array(pred_corrected, SUBSEGMENT_LABELS)) + final_fornix = np.sum(pred_corrected == FORNIX_LABEL) + final_wm = np.sum((pred_corrected == 2) | (pred_corrected == 41)) + final_ventricles = np.sum((pred_corrected == 4) | (pred_corrected == 43)) + + logger.info(f"Final segmentation: CC={final_cc}, Fornix={final_fornix},\ + WM={final_wm}, Ventricles={final_ventricles}") + logger.info(f"Changes: CC +{final_cc-initial_cc}, Fornix {final_fornix-initial_fornix},\ + WM {final_wm-initial_wm}") + + if rta_fut is not None: + _ = rta_fut.result() + + sys.exit(0) + diff --git a/CorpusCallosum/segmentation/__init__.py b/CorpusCallosum/segmentation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/CorpusCallosum/segmentation/inference.py b/CorpusCallosum/segmentation/inference.py new file mode 100644 index 000000000..1ae1e6063 --- /dev/null +++ b/CorpusCallosum/segmentation/inference.py @@ -0,0 +1,309 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Iterator +from pathlib import Path +from typing import cast + +import nibabel as nib +import numpy as np +import torch +from monai import transforms +from numpy import typing as npt +from typing_extensions import overload + +from CorpusCallosum.data import constants +from CorpusCallosum.transforms.segmentation import CropAroundACPC +from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML +from FastSurferCNN.download_checkpoints import load_checkpoint_config_defaults +from FastSurferCNN.download_checkpoints import main as download_checkpoints +from FastSurferCNN.models.networks import FastSurferVINN +from FastSurferCNN.utils import Image3d, Image4d, Shape2d, Shape3d, Shape4d, Vector2d, nibabelImage +from FastSurferCNN.utils.parallel import thread_executor + + +def load_model(device: torch.device | None = None) -> FastSurferVINN: + """Load trained model from checkpoint. + + Parameters + ---------- + device : torch.device or None, optional + Device to load model to, by default None. + If None, uses CUDA if available, else CPU. + + Returns + ------- + FastSurferVINN + Loaded and initialized model in evaluation mode. + """ + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + params = { + "num_classes": 3, + "num_filters": 71, + "num_filters_interpol": 32, + "num_channels": 9, + "kernel_h": 3, + "kernel_w": 3, + "kernel_c": 1, + "stride_conv": 1, + "stride_pool": 2, + "pool": 2, + "height": 128, + "width": 128, + "base_res": 1.0, + "interpolation_mode": "bilinear", + "crop_position": "top_left", + "out_tensor_width": 320, + "out_tensor_height": 320, + } + model = FastSurferVINN(params) + + download_checkpoints(cc=True) + cc_config: dict[str, Path] = load_checkpoint_config_defaults( + "checkpoint", + filename=CC_YAML, + ) + checkpoint_path = constants.FASTSURFER_ROOT / cc_config['segmentation'] + + weights = torch.load(checkpoint_path, weights_only=True, map_location=device) + model.load_state_dict(weights) + model.eval() + model.to(device) + return model + + +def run_inference( + model: "torch.nn.Module", + image_slice: Image3d, + ac_center: Vector2d, + pc_center: Vector2d, + voxel_size: tuple[float, float], + device: torch.device | None = None, + transform: transforms.Transform | None = None +) -> tuple[np.ndarray[Shape4d, np.dtype[np.int_]], Image4d, Image4d]: + """Run inference on a single image slice. + + Parameters + ---------- + model : torch.nn.Module + Trained model. + image_slice : np.ndarray + LIA-oriented input image as numpy array of shape (L, I, A). + ac_center : np.ndarray + Anterior commissure coordinates. + pc_center : np.ndarray + Posterior commissure coordinates. + voxel_size : a pair of floats + Voxel size of inferior/superior and anterior/posterior direction in mm. + device : torch.device, optional + Device to run inference on. If None, uses the device of the model. + transform : transforms.Transform, optional + Custom transform pipeline. + + Returns + ------- + seg_labels : npt.NDArray[int] + The segmentation result. + inputs : npt.NDArray[float] + The inputs to the model. + soft_labels : npt.NDArray[float] + The softlabel output. + """ + if device is None: + device = next(model.parameters()).device + + crop_around_acpc = CropAroundACPC(keys=['image'], padding_mm=35, random_translate=0) + to_discrete = transforms.AsDiscrete(argmax=True, to_onehot=3) + + # Preprocess slice + _inputs = torch.from_numpy(image_slice[:,None]) #,:256,:256]) # artifact from training script + sample = {'image': _inputs, 'AC_center': ac_center, 'PC_center': pc_center, 'res': np.asarray(voxel_size)} + sample_cropped = crop_around_acpc(sample) + _inputs, to_pad = sample_cropped['image'], sample_cropped['to_pad'] + _inputs = transforms.utils.rescale_array(_inputs, 0, 1, dtype=np.float32).to(device) + + # split into slices with 9 channels each + # Generate views with sliding window of 9 slices + batch_size, channels, height, width = _inputs.shape + _inputs = _inputs.unfold(0, 9, 1).swapdims(-1, 1).reshape(-1, 9*channels, height, width) + + # Post-process outputs + with torch.no_grad(): + scale_factors = torch.ones((_inputs.shape[0], 2), device=device) / torch.asarray([voxel_size], device=device) + + _logits = model(_inputs, scale_factor=scale_factors) + _softlabels = transforms.Activations(softmax=True, dim=1)(_logits) + + softlabels = _softlabels.cpu().numpy() + _labels = torch.stack([to_discrete(i) for i in _softlabels]) + + # Pad back to original size, to_pad is a tuple[int, int, int, int] + pad_tuples = ((0, 0),) * 2 + (to_pad[:2], to_pad[2:]) + labels = np.pad(_labels.cpu().numpy(), pad_tuples, mode='constant', constant_values=0) + softlabels = np.pad(softlabels, pad_tuples, mode='constant', constant_values=0) + + return tuple(x.transpose(0, 2, 3, 1) for x in (labels, _inputs.cpu().numpy(), softlabels)) + + +def load_validation_data( + path: str | Path, +) -> tuple[npt.NDArray[str], npt.NDArray[float], npt.NDArray[float], Iterator[int], npt.NDArray[str], list[str]]: + """Load validation data from CSV file and compute label widths. + + Reads a CSV file containing image paths, label paths, and AC/PC coordinates, + then computes the width (number of slices with non-zero labels) for each label file. + + Parameters + ---------- + path : str or Path + Path to the CSV file containing validation data. The CSV should have columns: + image, label, AC_center_x, AC_center_y, AC_center_z, + PC_center_x, PC_center_y, PC_center_z. + + Returns + ------- + images : npt.NDArray[str] + Array of image file paths. + ac_centers : npt.NDArray[float] + Array of anterior commissure coordinates (x, y, z). + pc_centers : npt.NDArray[float] + Array of posterior commissure coordinates (x, y, z). + label_widths : Iterator[int] + Iterator yielding the number of slices with non-zero labels for each label file. + labels : npt.NDArray[str] + Array of label file paths. + subj_ids : list[str] + List of subject IDs (from CSV index). + """ + import pandas as pd + + data = pd.read_csv(path, index_col=0, header=None) + data.columns = ["image", "label", "AC_center_x", "AC_center_y", "AC_center_z", + "PC_center_x", "PC_center_y", "PC_center_z"] + + ac_centers = data[["AC_center_x", "AC_center_y", "AC_center_z"]].values + pc_centers = data[["PC_center_x", "PC_center_y", "PC_center_z"]].values + images = data["image"].values + labels = data["label"].values + subj_ids = data.index.values.tolist() + + def _load(label_path: str | Path) -> int: + """Compute the width of non-zero slices in a label image. + + Parameters + ---------- + label_path : str or Path + Path to the label image file + + Returns + ------- + int + Number of slices containing non-zero labels, or total slices if <= 100 + """ + label_img = cast(nibabelImage, nib.load(label_path)) + + if label_img.shape[0] > 100: + # check which slices have non-zero values + label_data = np.asarray(label_img.dataobj) + non_zero_slices = np.any(label_data > 0, axis=(1,2)) + first_nonzero = np.argmax(non_zero_slices) + last_nonzero = len(non_zero_slices) - np.argmax(non_zero_slices[::-1]) + return last_nonzero - first_nonzero + else: + return label_img.shape[0] + + label_widths = thread_executor().map(_load, data["label"]) + + return images, ac_centers, pc_centers, label_widths, labels, subj_ids + +@overload +def one_hot_to_label(one_hot: Image4d, label_ids: list[int] | None = None) -> np.ndarray[Shape3d, np.dtype[np.int_]]: ... + +@overload +def one_hot_to_label(one_hot: Image3d, label_ids: list[int] | None = None) -> np.ndarray[Shape2d, np.dtype[np.int_]]: ... + +def one_hot_to_label( + one_hot: np.ndarray[tuple[int, ...], np.dtype[np.bool_]], + label_ids: list[int] | None = None, +) -> np.ndarray[tuple[int, ...], np.dtype[np.int_]]: + """Convert one-hot encoded segmentation to label map. + + Converts a one-hot encoded segmentation array to discrete labels by taking + the argmax along the last axis and optionally mapping to specific label values. + + Parameters + ---------- + one_hot : np.ndarray of floats + One-hot encoded segmentation array of shape (..., num_classes). + label_ids : array_like of ints, optional + List of label IDs to map classes to. If None, defaults to [0, FORNIX_LABEL, CC_LABEL]. + The index in this list corresponds to the class index from argmax. + + Returns + ------- + npt.NDArray[int] + Label map with discrete integer labels. + """ + if label_ids is None: + from CorpusCallosum.data.constants import CC_LABEL, FORNIX_LABEL + label_ids = [0, FORNIX_LABEL, CC_LABEL] + + label = np.argmax(one_hot, axis=3) + if label_ids is not None: + label = np.asarray(label_ids)[label] + + return label + + +def run_inference_on_slice( + model: "torch.nn.Module", + test_slab: Image3d, + ac_center: Vector2d, + pc_center: Vector2d, + voxel_size: tuple[float, float], +) -> tuple[np.ndarray[Shape3d, np.dtype[np.int_]], Image4d, Image4d]: + """Run inference on a single slice. + + Parameters + ---------- + model : torch.nn.Module + Trained model for inference. + test_slab : np.ndarray + Input image slice. + ac_center : npt.NDArray[float] + Anterior commissure coordinates (Inferior and Anterior values). + pc_center : npt.NDArray[float] + Posterior commissure coordinates (Inferior and Posterior values). + voxel_size : a pair of floats + Voxel sizes in superior/inferior and anterior/posterior direction in mm. + + Returns + ------- + results: np.ndarray + Label map after one-hot conversion. + inputs: np.ndarray + Preprocessed input image. + outputs_soft: npt.NDArray[float] + Softlabel outputs (non-discrete). + + """ + # add zero in front of AC_center and PC_center + ac_center = np.concatenate([np.zeros(1), ac_center]) + pc_center = np.concatenate([np.zeros(1), pc_center]) + + _results, inputs, outputs_soft = run_inference(model, test_slab, ac_center, pc_center, voxel_size) + results = one_hot_to_label(_results) + + return results, inputs, outputs_soft diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py new file mode 100644 index 000000000..fb62ca94b --- /dev/null +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -0,0 +1,542 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from numpy import typing as npt +from scipy import integrate, ndimage +from scipy.spatial.distance import cdist +from skimage.measure import label + +import FastSurferCNN.utils.logging as logging +from CorpusCallosum.data.constants import CC_LABEL +from FastSurferCNN.utils import Mask3d, Shape3d + +logger = logging.get_logger(__name__) + + +def find_component_boundaries(labels_arr: npt.NDArray[int], component_id: int) -> npt.NDArray[int]: + """Find boundary voxels of a connected component. + + Parameters + ---------- + labels_arr : np.ndarray + Labeled array from connected components analysis. + component_id : int + ID of the component to find boundaries for. + + Returns + ------- + np.ndarray + Array of shape (N, 3) containing boundary coordinates. + + Notes + ----- + Uses 6-connectivity (face neighbors only) to determine boundaries. + Boundary voxels are those that are part of the component but have + at least one non-component neighbor. + """ + component_mask = labels_arr == component_id + + # Create a structuring element for 6-connectivity (face neighbors only) + struct = ndimage.generate_binary_structure(3, 1) + + # Erode the component to find internal voxels + eroded = ndimage.binary_erosion(component_mask, structure=struct) + + # Boundary is the difference between original and eroded + boundary = component_mask & ~eroded + + return np.array(np.where(boundary)).T + + +def find_minimal_connection_path( + boundary_coords1: np.ndarray, + boundary_coords2: np.ndarray, + max_distance: float = 3.0 +) -> tuple[np.ndarray, np.ndarray] | None: + """Find the minimal connection path between two component boundaries. + + Parameters + ---------- + boundary_coords1 : np.ndarray + Boundary coordinates of first component, shape (N1, 3). + boundary_coords2 : np.ndarray + Boundary coordinates of second component, shape (N2, 3). + max_distance : float, default=3.0 + Maximum distance to consider for connection, by default 3.0. + + Returns + ------- + tuple[np.ndarray, np.ndarray] or None + If a valid connection is found: + + - point1 : Coordinates on first boundary + - point2 : Coordinates on second boundary + + None if no connection within max_distance is found. + + Notes + ----- + Uses Euclidean distance to find the closest pair of points + between the two boundaries. + """ + if len(boundary_coords1) == 0 or len(boundary_coords2) == 0: + return None + + # Calculate pairwise distances between all boundary points + distances = cdist(boundary_coords1, boundary_coords2, metric='euclidean') + + # Find the minimum distance and corresponding points + min_idx = np.unravel_index(np.argmin(distances), distances.shape) + min_distance = distances[min_idx] + + if min_distance <= max_distance: + return boundary_coords1[min_idx[0]], boundary_coords2[min_idx[1]] + + return None + + +def create_connection_line(point1: np.ndarray, point2: np.ndarray) -> list[tuple[int, int, int]]: + """Create a line of voxels connecting two points. + + Uses a simplified 3D line algorithm to create a sequence of voxels + that form a continuous path between the two points. + + Parameters + ---------- + point1 : np.ndarray + Starting point coordinates, shape (3,). + point2 : np.ndarray + Ending point coordinates, shape (3,). + + Returns + ------- + list[tuple[int, int, int]] + List of (x, y, z) coordinates forming the connection line. + + Notes + ----- + The line is created by interpolating between the points using + the maximum distance in any dimension as the number of steps. + """ + x1, y1, z1 = map(int, point1) + x2, y2, z2 = map(int, point2) + + line_points = [] + + # Calculate the number of steps needed + dx = abs(x2 - x1) + dy = abs(y2 - y1) + dz = abs(z2 - z1) + + steps = max(dx, dy, dz) + + if steps == 0: + return [(x1, y1, z1)] + + # Calculate increments for each dimension + x_inc = (x2 - x1) / steps + y_inc = (y2 - y1) / steps + z_inc = (z2 - z1) / steps + + # Generate points along the line + for i in range(steps + 1): + x = int(round(x1 + i * x_inc)) + y = int(round(y1 + i * y_inc)) + z = int(round(z1 + i * z_inc)) + line_points.append((x, y, z)) + + return line_points + + +def connect_nearby_components(seg_arr: np.ndarray, max_connection_distance: float = 3.0) -> np.ndarray: + """Connect nearby disconnected components that should be connected. + + This function identifies disconnected components in the segmentation and creates + minimal connections between components that are close to each other. + + Parameters + ---------- + seg_arr : np.ndarray + Input binary segmentation array. + max_connection_distance : float, optional + Maximum distance to connect components, by default 3.0. + + Returns + ------- + np.ndarray + Segmentation array with minimal connections added between nearby components. + + Notes + ----- + The function: + 1. Identifies connected components in the input segmentation + 2. Finds boundaries between components + 3. Creates minimal connections between nearby components + 4. Returns the modified segmentation with added connections + """ + + # Create a copy to modify + connected_seg = seg_arr.copy() + + # Find connected components without dilation first + labels_cc = label(seg_arr, connectivity=3, background=0) + + # Get component sizes (excluding background) + bincount = np.bincount(labels_cc.flat) + component_ids = np.where(bincount > 0)[0][1:] # Exclude background (0) + + if len(component_ids) <= 1: + return connected_seg # Only one component, no connections needed + + # Sort components by size (largest first) + component_sizes = [(comp_id, bincount[comp_id]) for comp_id in component_ids] + component_sizes.sort(key=lambda x: x[1], reverse=True) + + # Use the largest component as the reference + main_component_id = component_sizes[0][0] + + logger.info(f"Found {len(component_ids)} disconnected components. " + f"Attempting to connect smaller components to main component (size: {component_sizes[0][1]})") + + connections_made = 0 + + # Try to connect each smaller component to the main component + for comp_id, comp_size in component_sizes[1:]: + if comp_size < 5: # Skip very small components (likely noise) + logger.debug(f"Skipping tiny component {comp_id} with size {comp_size}") + continue + + # Find boundaries of both components + main_boundary = find_component_boundaries(labels_cc, main_component_id) + comp_boundary = find_component_boundaries(labels_cc, comp_id) + + # Find minimal connection path + connection = find_minimal_connection_path(main_boundary, comp_boundary, max_connection_distance) + + if connection is not None: + point1, point2 = connection + distance = np.linalg.norm(point2 - point1) + + logger.debug(f"Connecting component {comp_id} (size: {comp_size}) to main component. " + f"Distance: {distance:.2f} voxels") + + # Create connection line + connection_line = create_connection_line(point1, point2) + + # Add connection voxels to the segmentation + # Use the same label as the original segmentation at the connection points + connection_label = seg_arr[point1[0], point1[1], point1[2]] if \ + seg_arr[point1[0], point1[1], point1[2]] != 0 else \ + seg_arr[point2[0], point2[1], point2[2]] + + for x, y, z in connection_line: + if (0 <= x < connected_seg.shape[0] and + 0 <= y < connected_seg.shape[1] and + 0 <= z < connected_seg.shape[2]): + if connected_seg[x, y, z] == 0: # Only fill empty voxels + connected_seg[x, y, z] = connection_label + + connections_made += 1 + else: + logger.debug(f"Component {comp_id} (size: {comp_size}) too far from main component") + + logger.info(f"Created {connections_made} minimal connections between components") + + + # Plot components for visualization + # import matplotlib.pyplot as plt + # n_components = len(component_sizes) + # fig, axes = plt.subplots(1, n_components + 1, figsize=(5*(n_components + 1), 5)) + # if n_components == 1: + # axes = [axes] + # # Plot each component in a different color + # for i, (comp_id, comp_size) in enumerate(component_sizes): + # component_mask = labels_cc == comp_id + # axes[i].imshow(component_mask[component_mask.shape[0]//2], cmap='gray') + # axes[i].set_title(f'Component {comp_id}\nSize: {comp_size}') + # axes[i].axis('off') + + # # Plot the connected segmentation + # axes[-1].imshow(connected_seg[connected_seg.shape[0]//2], cmap='gray') + # axes[-1].set_title('Connected Segmentation') + # axes[-1].axis('off') + # plt.tight_layout() + # plt.show() + + return connected_seg + + +def get_cc_volume_voxel( + desired_width_mm: int, + cc_mask: np.ndarray, + voxel_size: tuple[float, float, float], +) -> float: + """Calculate the volume of the corpus callosum in cubic millimeters. + + This function calculates the volume of the corpus callosum (CC) in cubic millimeters. + If the CC width is larger than desired_width_mm, the voxels on the edges are calculated as + partial volumes to achieve the desired width. + + Parameters + ---------- + desired_width_mm : int + Desired width of the CC in millimeters. + cc_mask : np.ndarray + Binary mask of the corpus callosum in LIA orientation. + voxel_size : triplet of floats + LIA-oriented Voxel size in millimeters (x, y, z). + + Returns + ------- + float + Volume of the CC in cubic millimeters. + + Raises + ------ + ValueError + If CC width is smaller than desired width + AssertionError + If CC mask doesn't have odd number of voxels in x dimension + + Notes + ----- + The function assumes LIA orientation where: + - x dimension corresponds to Left/Right + - y dimension corresponds to Inferior/Superior + - z dimension corresponds to Anterior/Posterior + """ + assert cc_mask.shape[0] % 2 == 1, "CC mask must have odd number of voxels in x dimension" + + + # Calculate voxel volume + voxel_volume: float = np.prod(voxel_size, dtype=float) + voxel_width: float = voxel_size[0] + + # Get width of CC mask in voxels by finding the extent in x dimension + width_vox = np.sum(np.any(cc_mask, axis=(1,2))) + + # we are in LIA, so 0 is L/R resolution + width_mm = width_vox * voxel_width + + if width_mm == desired_width_mm: + return np.sum(cc_mask) * voxel_volume + elif width_mm > desired_width_mm: + # remainder on the left/right side of the CC mask + desired_width_vox = desired_width_mm / voxel_width + fraction_of_voxel_at_edge = (desired_width_vox % 1) / 2 + + if fraction_of_voxel_at_edge > 0: + # make sure the assumentation is correct that the CC mask has an odd number of voxels + # and the leftmost and rightmost voxels are the edges at the desired width + cc_width_vox = int(np.floor(desired_width_vox) + 1) + cc_width_vox = cc_width_vox + 1 if cc_width_vox % 2 == 0 else cc_width_vox + + assert cc_mask.shape[0] == cc_width_vox, (f"CC mask should have {cc_width_vox} voxels, " + f"but has {cc_mask.shape[0]}") + + left_partial_volume = np.sum(cc_mask[0]) * voxel_volume * fraction_of_voxel_at_edge + right_partial_volume = np.sum(cc_mask[-1]) * voxel_volume * fraction_of_voxel_at_edge + center_volume = np.sum(cc_mask[1:-1]) * voxel_volume + return left_partial_volume + right_partial_volume + center_volume + else: + raise ValueError(f"Width of CC segmentation is smaller than desired width: {width_mm} < {desired_width_mm}") + + +def get_cc_volume_contour( + cc_contours: list[np.ndarray], + voxel_size: tuple[float, float, float], +) -> float: + """Calculate the volume of the corpus callosum using Simpson's rule. + + Parameters + ---------- + cc_contours : list[np.ndarray] + List of CC contours for each slice in the left-right direction. + voxel_size : triplet of floats + Voxel size in millimeters (x, y, z). + + Returns + ------- + float + Volume of the CC in cubic millimeters. + + Raises + ------ + ValueError + If CC width is smaller than desired width or insufficient contours for Simpson's rule + + Notes + ----- + This function calculates the volume of the corpus callosum (CC) in cubic millimeters + using Simpson's rule. If the CC width is larger than desired_width_mm, the voxels on + the edges are calculated as partial volumes to achieve the desired width. + """ + # FIXME: This function is a shape-tool, it should therefore not be in segmentation.postprocessing... + # FIXME: this code currently produces volume estimates more that 50% off of the volume_based estimate in + # get_cc_volume_voxel... + + if len(cc_contours) < 3: + raise ValueError("Need at least 3 contours for Simpson's rule integration") + + # FIXME: why can we not multiply by those numbers in line below other FIXME comment + # converting this to a warning for now... + if voxel_size[1] == voxel_size[2]: + logger.warning("voxel sizes in get_cc_volume_contour, currently volume must be isotropic!") + # Calculate cross-sectional areas for each contour + areas = [] + for contour in cc_contours: + # Calculate area using the shoelace formula for polygon area + if contour.shape[1] < 3: + areas.append(0.0) + else: + # FIXME: we are multiplying by voxel size here and below "Convert from voxel^2 to mm^2", e.g. + # x = contour[0] * voxel_size[1] + # y = contour[1] * voxel_size[2] + contour = contour * voxel_size[1] + x = contour[0] + y = contour[1] + # Shoelace formula: A = 0.5 * |sum(x_i * y_{i+1} - x_{i+1} * y_i)| + area = 0.5 * np.abs(np.sum(x[:-1] * y[1:] - x[1:] * y[:-1])) + # Convert from voxel^2 to mm^2 + area_mm2 = area * voxel_size[1] * voxel_size[2] # y * z voxel dimensions + areas.append(area_mm2) + + areas = np.array(areas) + + # Calculate spacing between slices (left-right direction) + lr_spacing = voxel_size[0] # x-direction voxel size + + measurement_points = np.arange(-voxel_size[0]*(areas.shape[0]//2), + voxel_size[0]*((areas.shape[0]+1)//2), lr_spacing) + + # FIXME: why interpolate at 0.25? Also, why do we need interpolaton at all? + # interpolate areas at 0.25 and 5 + areas_interpolated = np.interp(x=[-2.5, 2.5], + xp=measurement_points, + fp=areas) + + # remove measurement points that are outside of the desired range + # not sure if this can happen, but let's be safe + outside_range = (measurement_points < -2.5) | (measurement_points > 2.5) + measurement_points = [-2.5] + measurement_points[~outside_range].tolist() + [2.5] + areas = [areas_interpolated[0]] + areas[~outside_range].tolist() + [areas_interpolated[1]] + + + # can also use trapezoidal rule + return integrate.simpson(areas, x=measurement_points) + + +def extract_largest_connected_component( + seg_arr: Mask3d, + max_connection_distance: float = 3.0, +) -> Mask3d: + """Get largest connected component from a binary segmentation array. + + Parameters + ---------- + seg_arr : np.ndarray + Input binary segmentation array. + max_connection_distance : float, optional + Maximum distance to connect components, by default 3.0. + + Returns + ------- + np.ndarray + Binary mask of the largest connected component. + + Notes + ----- + The function first attempts to connect nearby disconnected components + that should be connected, then finds the largest connected component. + It uses minimal connections between close components before falling + back to dilation if no connections are made. + """ + # First attempt: try to connect nearby components with minimal connections + connected_seg = connect_nearby_components(seg_arr, max_connection_distance) + + # Check if connections were successful by comparing connectivity + original_labels = label(seg_arr, connectivity=3, background=0) + connected_labels = label(connected_seg, connectivity=3, background=0) + + original_components = len(np.unique(original_labels)) - 1 # Exclude background + connected_components = len(np.unique(connected_labels)) - 1 # Exclude background + + if connected_components < original_components: + logger.info(f"Successfully reduced components from {original_components} to {connected_components} " + "using minimal connections") + mask = connected_seg + # else: + # logger.info("No connections made, falling back to dilation approach") + # # Fallback: use the original dilation approach + # struct1 = ndimage.generate_binary_structure(3, 3) + # mask = ndimage.binary_dilation(seg_arr, structure=struct1, iterations=1).astype(np.uint8) + + # Get connected components from the processed mask + labels_cc = label(mask, connectivity=3, background=0) + + # Get component counts + bincount = np.bincount(labels_cc.flat) + + # Get background label (assumed to be the largest component) + background = np.argmax(bincount) + bincount[background] = -1 + + # Get largest connected component + largest_cc = np.equal(labels_cc, np.argmax(bincount)) + + return largest_cc + + +def clean_cc_segmentation( + seg_arr: np.ndarray[Shape3d, np.dtype[np.int_]], + max_connection_distance: float = 3.0, +) -> tuple[np.ndarray[Shape3d, np.dtype[np.int_]], Mask3d]: + """Clean corpus callosum segmentation by removing non-connected components. + + Parameters + ---------- + seg_arr : npt.NDArray[int] + Input segmentation array with CC (192) and fornix (250) labels. + max_connection_distance : float, default=3.0 + Maximum distance to connect components. + + Returns + ------- + clean_seg : np.NDArray[int] + Cleaned segmentation array with only the largest connected component of CC and fornix. + mask : npt.NDArray[bool] + Binary mask of the largest connected component. + + Notes + ----- + The function: + 1. Isolates the CC (label 192) + 2. Attempts to connect nearby disconnected components + 3. Adds the fornix (label 250) + 4. Removes non-connected components from the combined CC and fornix + """ + from functools import partial + + extract_largest = partial(extract_largest_connected_component, max_connection_distance=max_connection_distance) + + # Remove non-connected components from the CC alone, with minimal connections + mask = np.equal(seg_arr, CC_LABEL) + cc_seg = mask.astype(int) * CC_LABEL + cc_label_cleaned = np.concatenate([extract_largest(seg[None]) * CC_LABEL for seg in cc_seg], axis=0) + + # Add fornix to the CC labels + clean_seg = np.where(mask, cc_label_cleaned, seg_arr) + + return clean_seg, np.greater(cc_label_cleaned, 0) diff --git a/CorpusCallosum/shape/__init__.py b/CorpusCallosum/shape/__init__.py new file mode 100644 index 000000000..4950a2427 --- /dev/null +++ b/CorpusCallosum/shape/__init__.py @@ -0,0 +1,15 @@ + +from CorpusCallosum.shape import endpoint_heuristic, mesh, metrics, postprocessing, subsegment_contour, thickness +from CorpusCallosum.shape.contour import CCContour +from CorpusCallosum.shape.mesh import CCMesh + +__all__ = [ + "CCContour", + "CCMesh", + "endpoint_heuristic", + "mesh", + "metrics", + "postprocessing", + "subsegment_contour", + "thickness", +] diff --git a/CorpusCallosum/shape/contour.py b/CorpusCallosum/shape/contour.py new file mode 100644 index 000000000..272b7a5fc --- /dev/null +++ b/CorpusCallosum/shape/contour.py @@ -0,0 +1,694 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module provides the ``CCContour`` class for reading, writing, and +manipulating 2D corpus callosum contours together with per-vertex thickness +values. Typical template outputs (from ``fastsurfer_cc.py --save_template``) +emit one set per slice: + +- ``contour_.txt``: CSV with header ``New contour, anterior_endpoint_idx=, posterior_endpoint_idx=

`` followed + by ``x,y`` rows. +- ``thickness_values_.txt``: CSV with header ``thickness`` and one value per contour vertex. +- ``thickness_measurement_points_.txt``: CSV with header ``vertex_idx`` listing the vertices where thickness was + measured. +""" + +import re +from pathlib import Path +from typing import Literal + +import lapy +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import scipy.interpolate +from scipy.ndimage import gaussian_filter1d + +import FastSurferCNN.utils.logging as logging +from CorpusCallosum.shape.endpoint_heuristic import smooth_contour +from CorpusCallosum.shape.thickness import cc_thickness, make_mesh_from_contour + +logger = logging.get_logger(__name__) + + +class CCContour: + """A class for representing and manipulating corpus callosum (CC) contours. + + This class provides functionality for manipulating and analyzing corpus callosum contours. + + Attributes + ---------- + contour : np.ndarray + Array of shape (N, 2) containing 2D contour points. + thickness_values : np.ndarray + Array of shape (N,) for thickness measurements for each contour point. + endpoint_idxs : tuple[int, int] + Tuple containing start and end indices for the contour. + + Examples + -------- + >>> from CorpusCallosum.shape.contour import CCContour + >>> + >>> contour = CCContour(contour_points, thickness_values, + >>> endpoint_idxs=(anterior_idx, posterior_idx), + >>> resolution=1.0) + >>> contour.fill_thickness_values() # interpolate missing values + >>> contour.smooth_contour(window_size=5) + >>> contour.save_contour("contour_0.txt") + >>> contour.save_thickness_values("thickness_values_0.txt") + >>> contour.save_thickness_measurement_points("thickness_measurement_points_0.txt") + """ + + def __init__( + self, + contour: np.ndarray[tuple[Literal["N", 2]], np.dtype[np.float_]], + thickness_values: np.ndarray[tuple[Literal["N"]], np.dtype[np.float_]], + endpoint_idxs: tuple[int, int] | None = None, + resolution: float = 1.0 + ): + """Initialize a CCContour object. + + Parameters + ---------- + contour : np.ndarray + Array of shape (N, 2) containing 2D contour points. + thickness_values : np.ndarray + Array of thickness measurements for each contour point. + endpoint_idxs : tuple[int, int], optional + Tuple containing start and end indices for the contour. + resolution : float, default=1.0 + The left-right spacing. + """ + self.contour = contour + if self.contour.shape[1] != 2: + raise ValueError(f"Contour must be a 2D array, but is {self.contour.shape}") + self.thickness_values = thickness_values + if self.contour.shape[0] != len(thickness_values): + raise ValueError( + f"Number of contour points ({self.contour.shape[0]}) does not match number of thickness values " + f"({len(thickness_values)})", + ) + # write vertex indices where thickness values are not nan + self.original_thickness_vertices = np.where(~np.isnan(thickness_values))[0] + self.resolution = resolution + + if endpoint_idxs is None: + self.endpoint_idxs = (0, len(contour) // 2) + else: + self.endpoint_idxs = endpoint_idxs + + def smooth_contour(self, window_size: int = 5) -> None: + """Smooth a contour using a moving average filter. + + Parameters + ---------- + window_size : int, default=5 + Size of the smoothing window. + + Notes + ----- + Uses smooth_contour from cc_endpoint_heuristic module to: + 1. Extract x and y coordinates. + 2. Apply moving average smoothing. + 3. Update contour with smoothed coordinates. + """ + x, y = self.contour.T + x, y = smooth_contour(x, y, window_size) + self.contour = np.array([x, y]).T + + def copy(self) -> "CCContour": + """Copy the contour. + """ + return CCContour(self.contour.copy(), self.thickness_values.copy(), self.endpoint_idxs, self.resolution) + + def get_contour_edge_lengths(self) -> np.ndarray: + """Get the lengths of the edges of a contour. + + Returns + ------- + np.ndarray + Array of edge lengths for the contour. + + Notes + ----- + Edge lengths are calculated as Euclidean distances between consecutive points + in the contour. + """ + edges = np.diff(self.contour, axis=0) + return np.sqrt(np.sum(edges**2, axis=1)) + + def create_levelpaths(self, + num_points: int, + update_data: bool = True + ) -> tuple[list[np.ndarray], list[float]]: + midline_len, thickness, curvature, midline_equi, \ + levelpaths, contour_with_thickness, endpoint_idxs = cc_thickness( + self.contour, + self.endpoint_idxs, + n_points=num_points, + ) + + if update_data: + self.contour = contour_with_thickness[:, :2] + self.thickness_values = contour_with_thickness[:,2] + self.original_thickness_vertices = np.where(~np.isnan(self.thickness_values))[0] + self.endpoint_idxs = endpoint_idxs + + return levelpaths, thickness + + def set_thickness_values(self, thickness_values: np.ndarray, use_measurement_points: bool = False) -> None: + """Set the thickness values for the contour. + This is useful to update the thickness values for specific plots. + + Parameters + ---------- + thickness_values : np.ndarray + Array of thickness values for the contour. + use_measurement_points : bool, optional + Whether to use the measurement points to set the thickness values, by default False. + """ + if use_measurement_points: + if len(thickness_values) == len(self.original_thickness_vertices): + self.thickness_values = np.full(len(self.contour), np.nan) + self.thickness_values[self.original_thickness_vertices] = thickness_values + else: + raise ValueError( + "Number of thickness values does not match number of measurement points " + f"{len(self.original_thickness_vertices)}.", + ) + else: + if len(thickness_values) != len(self.contour): + raise ValueError( + f"The number of thickness values does not match number of points in the contour " + f"{len(self.contour)}.", + ) + self.thickness_values = thickness_values + + def fill_thickness_values(self) -> None: + """Interpolate missing thickness values using weighted averaging. + + Notes + ----- + The function: + 1. Processes each contour with missing thickness values. + 2. For each missing value: + - Finds two closest points with known thickness. + - Calculates distances along contour. + - Computes weighted average based on inverse distance. + 3. Updates thickness values in place. + + The weights are calculated as inverse distances to ensure closer + points have more influence on the interpolated value. + + """ + thickness = self.thickness_values + edge_lengths = self.get_contour_edge_lengths() + + # Find indices of points with known thickness + known_idx = np.where(~np.isnan(thickness))[0] + + if len(known_idx) == 0: + logger.warning("No known thickness values; skipping interpolation") + return + if len(known_idx) == 1: + logger.warning("Only one known thickness value; skipping interpolation") + thickness[np.isnan(thickness)] = thickness[known_idx[0]] + self.thickness_values = thickness + return + + # For each point with unknown thickness + for j in range(len(thickness)): + if not np.isnan(thickness[j]): + continue + + # Find two closest points with known thickness + distances = np.zeros(len(known_idx)) + for k, idx in enumerate(known_idx): + # Calculate distance along contour by summing edge lengths + if idx > j: + distances[k] = np.sum(edge_lengths[j:idx]) + else: + distances[k] = np.sum(edge_lengths[idx:j]) + + # Get indices of two closest points + closest_indices = known_idx[np.argsort(distances)[:2]] + closest_distances = np.sort(distances)[:2] + + # Calculate weights based on inverse distance + weights = 1.0 / closest_distances + weights = weights / np.sum(weights) + + # Calculate weighted average thickness + thickness[j] = np.sum(weights * thickness[closest_indices]) + + self.thickness_values = thickness + + def smooth_thickness_values(self, iterations: int = 1) -> None: + """Smooth the thickness values using a Gaussian filter. + + Parameters + ---------- + iterations : int, optional + Number of smoothing iterations, by default 1. + + Notes + ----- + Applies Gaussian smoothing with sigma=5 to thickness values + for each slice that has measurements. + """ + for i in range(len(self.thickness_values)): + if self.thickness_values[i] is not None: + self.thickness_values[i] = gaussian_filter1d(self.thickness_values[i], sigma=5) + + def plot_contour(self, output_path: str | None = None) -> None: + """Plot a single contour with thickness values. + + Parameters + ---------- + output_path : str + Path where to save the plot. + + Notes + ----- + Creates a 2D visualization with: + - Points colored by thickness values. + - Gray points for missing thickness values. + - Connected contour line. + - Grid, labels, and legend. + """ + if output_path is not None: + self.__make_parent_folder(output_path) + + contour = self.contour + + plt.figure(figsize=(10, 10)) + # Get thickness values for this slice + thickness = self.thickness_values + + # Plot points with colors based on thickness + for i in range(len(contour)): + if np.isnan(thickness[i]): + plt.plot(contour[i, 0], contour[i, 1], "o", color="gray", markersize=1) + else: + # Map thickness to color from red to yellow + plt.plot( + contour[i, 0], + contour[i, 1], + "o", + color=plt.cm.YlOrRd(thickness[i] / np.nanmax(thickness)), + markersize=1, + ) + + # Connect points with lines + plt.plot(contour[:, 0], contour[:, 1], "-", color="black", alpha=0.3, label="Contour") + plt.axis("equal") + plt.xlabel("X") + plt.ylabel("Y") + plt.title("CC contour") + plt.legend() + plt.grid(True) + plt.tight_layout() + if output_path is not None: + plt.savefig(output_path, dpi=300) + else: + plt.show() + + + def plot_contour_colorfill( + self, + plot_values: np.ndarray, + title: str | None = None, + save_path: str | None = None, + colorbar: bool = True, + mode: str = "p-value", + ) -> matplotlib.figure.Figure: + """Plot a contour with levelset visualization. + + Creates a visualization of a contour with interpolated levelsets, useful for + analyzing the thickness distribution across the corpus callosum. + + Parameters + ---------- + plot_values : np.ndarray + Array of values to plot on CC from anterior to posterior (left to right in the plot). + title : str, optional + Title for the plot. + save_path : str, optional + Path to save the plot. If None, displays interactively. + colorbar : bool, default=True + Whether to show the colorbar. + mode : {"p-value", "icc", "thickness"}, default="p-value" + Mode of the plot. + + Returns + ------- + matplotlib.figure.Figure + The created figure object. + """ + plot_values = plot_values[::-1] # make sure values are plotted left to right (anterior to posterior) + + points, _ = make_mesh_from_contour(self.contour, max_volume=0.5, min_angle=25, verbose=False) + + # make points 3D by adding zero + points = np.column_stack([points, np.zeros(len(points))]) + + levelpaths, _ = self.create_levelpaths(num_points=len(plot_values)-1, update_data=False) + + outside_contour = self.contour.T + + # Create a grid of points covering the contour area with higher resolution + x_min, x_max = np.min(outside_contour[0]), np.max(outside_contour[0]) + y_min, y_max = np.min(outside_contour[1]), np.max(outside_contour[1]) + margin = 1 + resolution = 0.05 # Higher resolution for smoother interpolation + x_grid, y_grid = np.meshgrid( + np.arange(x_min - margin, x_max + margin, resolution), np.arange(y_min - margin, y_max + margin, resolution) + ) + + # Create a path from the outside contour + contour_path = matplotlib.path.Path(np.column_stack([outside_contour[0], outside_contour[1]])) + + # Check which points are inside the contour + points = np.column_stack([x_grid.flatten(), y_grid.flatten()]) + mask = contour_path.contains_points(points).reshape(x_grid.shape) + + # Collect all levelpath points and their corresponding values + # Extend each levelpath at both ends to improve extrapolation + all_level_points_x = [] + all_level_points_y = [] + all_level_values = [] + + for i, path in enumerate(levelpaths): + + # add third dimension to path + path = np.column_stack([path, np.zeros(len(path))]) + + if len(path) == 1: + all_level_points_x.append(path[0][0]) + all_level_points_y.append(path[0][1]) + all_level_values.append(plot_values[i]) + continue + + # make levelpath + path = lapy.TriaMesh._TriaMesh__resample_polygon(path, 1000) + + # Extend at the beginning: add point in direction opposite to first segment + first_segment = path[1] - path[0] + # standardize length of first segment + first_segment = first_segment / np.linalg.norm(first_segment) * 10 + extension_start = path[0] - first_segment + all_level_points_x.append(extension_start[0]) + all_level_points_y.append(extension_start[1]) + all_level_values.append(plot_values[i]) + + # Add original path points + for point in path: + all_level_points_x.append(point[0]) + all_level_points_y.append(point[1]) + all_level_values.append(plot_values[i]) + + # Extend at the end: add point in direction of last segment + last_segment = path[-1] - path[-2] + # standardize length of last segment + last_segment = last_segment / np.linalg.norm(last_segment) * 10 + extension_end = path[-1] + last_segment + all_level_points_x.append(extension_end[0]) + all_level_points_y.append(extension_end[1]) + all_level_values.append(plot_values[i]) + + # Convert to numpy arrays + all_level_points_x = np.array(all_level_points_x) + all_level_points_y = np.array(all_level_points_y) + all_level_values = np.array(all_level_values) + + # Use griddata to perform smooth interpolation - using 'linear' instead of 'cubic' + # and properly formatting the input points + grid_values = scipy.interpolate.griddata( + (all_level_points_x, all_level_points_y), all_level_values, (x_grid, y_grid), method="linear", fill_value=0, + ) + + # smooth the grid_values + grid_values = scipy.ndimage.gaussian_filter(grid_values, sigma=5, radius=5) + + # Apply the mask to only show values inside the contour + masked_values = np.where(mask, grid_values, np.nan) + + if mode == "p-value": + # Sample colormaps + colors1 = plt.cm.binary([0.4] * 128) + colors2 = plt.cm.hot(np.linspace(0.8, 0.1, 128)) + elif mode == "icc": + colors1 = plt.cm.Blues(np.linspace(0, 1, 128)) + colors2 = plt.cm.binary([0.4] * 128) + elif mode == "thickness": + # Blue to red colormap for thickness values + cmap = plt.cm.coolwarm + else: + raise ValueError(f"Invalid mode '{mode}'") + + # Combine the color samples for p-value and icc modes + if mode != "thickness": + colors = np.vstack((colors2, colors1)) + # Create a new colormap + cmap = matplotlib.colors.LinearSegmentedColormap.from_list("my_colormap", colors) + + # Plot CC contour with levelsets + fig = plt.figure(figsize=(10, 3)) + # Apply a 10-degree rotation to the entire plot + base = plt.gca().transData + transform = matplotlib.transforms.Affine2D().rotate_deg(10) + transform = transform + base + + # Plot the filled contour with interpolated colors + plt.imshow( + masked_values, + extent=(x_min - margin, x_max + margin, y_min - margin, y_max + margin), + origin="lower", + cmap=cmap, + alpha=1, + interpolation="bilinear", + vmin=0 if mode != "thickness" else np.nanmin(plot_values), + vmax=0.10 if mode == "p-value" else (1 if mode == "icc" else np.nanmax(plot_values)), + transform=transform, + ) + + plt.imshow( + masked_values, + extent=(x_min - margin, x_max + margin, y_min - margin, y_max + margin), + origin="lower", + cmap=cmap, + alpha=1, + interpolation="bilinear", + vmin=0 if mode != "thickness" else np.nanmin(plot_values), + vmax=0.10 if mode == "p-value" else (1 if mode == "icc" else np.nanmax(plot_values)), + # norm=LogNorm(vmin=1e-3, vmax=0.1), # Set minimum to avoid log(0) + transform=transform, + ) + + if colorbar: + # Add a colorbar + cbar = plt.colorbar(aspect=15) + if mode == "p-value": + cbar.ax.set_ylim(0.001, 0.054) + cbar.ax.set_yticks([0.0, 0.01, 0.02, 0.03, 0.04, 0.05]) + cbar.set_label("p-value (log scale)") + elif mode == "icc": + cbar.ax.set_ylim(0, 1) + cbar.ax.set_yticks([0, 0.25, 0.5, 0.75, 1]) + cbar.ax.set_label("Intraclass correlation coefficient") + elif mode == "thickness": + # Set limits based on actual thickness values + thickness_min = np.nanmin(plot_values) + thickness_max = np.nanmax(plot_values) + cbar.ax.set_ylim(thickness_min, thickness_max) + cbar.set_label("Thickness (mm)") + + # Plot the outside contour on top for clear boundary + plt.plot(outside_contour[0], outside_contour[1], "k-", linewidth=2, label="CC Contour", transform=transform) + + plt.axis("equal") + plt.title(title, fontsize=14, fontweight="bold") + # plt.legend(loc='best') + plt.gca().invert_xaxis() + plt.axis("off") + if save_path is not None: + self.__make_parent_folder(save_path) + plt.savefig(save_path, dpi=300) + else: + plt.show() + return fig + + @staticmethod + def __make_parent_folder(filename: Path | str) -> None: + """Create the parent folder for a file if it doesn't exist. + + Parameters + ---------- + filename : Path, str + Path to the file whose parent folder should be created. + + Notes + ----- + Creates parent directory with parents=False to avoid creating + multiple levels of directories unintentionally. + """ + Path(filename).parent.mkdir(parents=False, exist_ok=True) + + def save_contour(self, output_path: Path | str) -> None: + """Save the contours to a CSV file. + + Parameters + ---------- + output_path : Path, str + Path to save the CSV file. + + Notes + ----- + The function saves contours in CSV format with: + - Header: slice_idx,x,y. + - Special lines indicating new contours with endpoint indices. + - Each point gets its own row with slice index and coordinates. + """ + self.__make_parent_folder(output_path) + logger.info(f"Saving contours to CSV file: {output_path}") + with open(output_path, "w") as f: + + f.write( + f"New contour, anterior_endpoint_idx={self.endpoint_idxs[0]}, " + f"posterior_endpoint_idx={self.endpoint_idxs[1]}\n" + ) + f.write("x,y\n") + for point in self.contour: + f.write(f"{point[0]},{point[1]}\n") + + def load_contour(self, input_path: str) -> None: + """Load contour from a CSV file. + + Parameters + ---------- + input_path : str + Path to the CSV file containing the contours. + + Raises + ------ + ValueError + If the file format doesn't match expected structure. + + Notes + ----- + The function: + 1. Reads CSV file with format matching save_contours output. + 2. Processes special lines for endpoint indices. + 3. Reconstructs contours and endpoint indices for each slice. + 4. Converts lists to fixed-size arrays with None padding. + """ + current_points = [] + self.contours = [] + self.endpoint_idxs = [] + + with open(input_path) as f: + header = next(f).strip() + # Parse endpoint indices from header + anterior_match = re.search(r'anterior_endpoint_idx=(\d+)', header) + posterior_match = re.search(r'posterior_endpoint_idx=(\d+)', header) + assert anterior_match and posterior_match, "Header does not contain endpoint indices" + + anterior_idx = int(anterior_match.group(1)) + posterior_idx = int(posterior_match.group(1)) + self.endpoint_idxs = (anterior_idx, posterior_idx) + + # Skip column names + next(f) + + for line in f: + x, y = line.strip().split(",") + current_points.append([float(x), float(y)]) + self.contour = np.array(current_points) + + def save_thickness_values(self, output_path: Path | str) -> None: + """Save thickness values to a CSV file. + + Parameters + ---------- + output_path : Path, str + Path to save the CSV file. + + Notes + ----- + The function saves thickness values in CSV format with: + - Header: thickness. + - Each thickness value gets its own row with slice index. + - Skips slices with no thickness values. + """ + self.__make_parent_folder(output_path) + logger.info(f"Saving thickness data to CSV file: {output_path}") + with open(output_path, "w") as f: + f.write("thickness\n") + for value in self.thickness_values: + f.write(f"{value}\n") + + def load_thickness_values( + self, + input_path: str, + ) -> None: + """Load thickness values from a CSV file. + + Parameters + ---------- + input_path : str + Path to the CSV file containing thickness values. + original_thickness_vertices_path : str or None, optional + Path to a file containing the indices of vertices where thickness + was measured, by default None. + + Raises + ------ + ValueError + If number of thickness values doesn't match measurement points + or if number of slices is inconsistent. + + Notes + ----- + The function: + 1. Reads thickness values from CSV file. + 2. Groups values by slice index. + 3. Optionally associates values with specific vertices. + 4. Handles both full contour and profile measurements. + + + """ + data = np.loadtxt(input_path, delimiter=",", skiprows=1) + if data.ndim == 0: + values = np.array([float(data)]) + elif data.ndim == 1: + values = data.astype(float) + else: + raise ValueError("Thickness values file must contain a single column") + + if len(values) != len(self.contour): + if np.sum(~np.isnan(values)) == len(self.original_thickness_vertices): + new_values = np.full(len(self.contour), np.nan) + new_values[self.original_thickness_vertices] = values[~np.isnan(values)] + else: + raise ValueError( + f"Number of thickness values {len(values)} does not match number of points in the " + f"contour {len(self.contour)} and current number of measururement points " + f"{len(self.original_thickness_vertices)} does not match the number of set thickness values " + f"{np.sum(~np.isnan(values))}." + ) + else: + raise ValueError(f"Number of thickness values in {input_path} does not match the vertices of the path!") + + self.thickness_values = new_values diff --git a/CorpusCallosum/shape/endpoint_heuristic.py b/CorpusCallosum/shape/endpoint_heuristic.py new file mode 100644 index 000000000..3bbe16b39 --- /dev/null +++ b/CorpusCallosum/shape/endpoint_heuristic.py @@ -0,0 +1,299 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Literal, overload + +import lapy +import numpy as np +import scipy.ndimage +import skimage.measure +from scipy.ndimage import label + +from FastSurferCNN.utils import Vector2d, Mask2d + + +def smooth_contour(x: np.ndarray, y: np.ndarray, window_size: int) -> tuple[np.ndarray, np.ndarray]: + """Smooth a contour using a moving average filter. + + Parameters + ---------- + x : np.ndarray + X-coordinates of the contour points. + y : np.ndarray + Y-coordinates of the contour points. + window_size : int + Size of the smoothing window. Must be odd and > 2. + + Returns + ------- + tuple[np.ndarray, np.ndarray] + Smoothed x and y coordinates of the contour. + """ + # Ensure window_size is an integer + window_size = int(window_size) + + if window_size // 2 == 0: + raise ValueError(f"Smoothing window size of {window_size} is too small") + + # Ensure the window size is odd + if window_size % 2 == 0: + window_size += 1 + + # Create a padded version of the arrays to handle the edges + x_padded = np.pad(x, (window_size // 2, window_size // 2), mode="wrap") + y_padded = np.pad(y, (window_size // 2, window_size // 2), mode="wrap") + + # Apply moving average + x_smoothed = np.zeros_like(x) + y_smoothed = np.zeros_like(y) + + for i in range(len(x)): + x_smoothed[i] = np.mean(x_padded[i : i + window_size]) + y_smoothed[i] = np.mean(y_padded[i : i + window_size]) + + # remove padding + x_smoothed = x_smoothed[window_size // 2:-window_size // 2] + y_smoothed = y_smoothed[window_size // 2:-window_size // 2] + + return x_smoothed, y_smoothed + + +def connect_diagonally_connected_components(cc_mask: np.ndarray) -> None: + """Connect diagonally connected components in the CC mask. + + Parameters + ---------- + cc_mask : np.ndarray + Binary mask of the corpus callosum. + + Notes + ----- + Modifies the input mask in-place to connect diagonally connected components. + """ + + # Create padded mask to handle boundary conditions + padded_mask = np.pad(cc_mask, pad_width=1, mode='constant', constant_values=0) + + # Get center pixels and diagonal neighbors + center = padded_mask[1:-1, 1:-1] + + # Direct neighbors (4-connectivity) + left = padded_mask[1:-1, :-2] # left + right = padded_mask[1:-1, 2:] # right + up = padded_mask[:-2, 1:-1] # up + down = padded_mask[2:, 1:-1] # down + + # Diagonal neighbors + up_left = padded_mask[:-2, :-2] # up-left + up_right = padded_mask[:-2, 2:] # up-right + down_left = padded_mask[2:, :-2] # down-left + down_right = padded_mask[2:, 2:] # down-right + + potential_diagonal_gaps = (center == 0) & ( + ((up_left > 0) & ((right > 0) | (down > 0))) | + ((up_right > 0) & ((left > 0) | (down > 0))) | + ((down_left > 0) & ((right > 0) | (up > 0))) | + ((down_right > 0) & ((left > 0) | (up > 0))) + ) + + + # Get connected components before filling using 4-connectivity + # This way, diagonal-only connections are treated as separate components + structure_4conn = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]) + _, num_components_before = label(cc_mask, structure=structure_4conn) + + # For each potential gap, check if filling it would reduce the number of components + connects_diagonals = np.zeros_like(potential_diagonal_gaps) + gap_positions = np.where(potential_diagonal_gaps) + + for i, j in zip(gap_positions[0], gap_positions[1], strict=True): + # Temporarily fill this gap + test_mask = cc_mask.copy() + test_mask[i, j] = 1 + + # Check connected components after filling + _, num_components_after = label(test_mask, structure=structure_4conn) + + # Only fill if it actually connects previously disconnected components + if num_components_after < num_components_before: + connects_diagonals[i, j] = True + + # Fill the identified diagonal gaps that actually improve connectivity + cc_mask[connects_diagonals] = 1 + + +def extract_cc_contour(cc_mask: np.ndarray, contour_smoothing: int = 5) -> np.ndarray: + """Extract the contour of the CC from the mask. + + Parameters + ---------- + cc_mask : np.ndarray + Binary mask of the corpus callosum. + contour_smoothing : int, default=5 + Window size for contour smoothing. + + Returns + ------- + np.ndarray + Array of shape (2, N) containing x,y coordinates of the contour points. + """ + # cc_mask_orig = cc_mask + cc_mask = cc_mask.copy() + + connect_diagonally_connected_components(cc_mask) + + contour = skimage.measure.find_contours(cc_mask, level=0.5)[0].T + contour = np.array(smooth_contour(contour[0], contour[1], contour_smoothing)) + + # plot contour + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots(1,2,figsize=(10, 8)) + # ax[0].imshow(cc_mask_orig) + # ax[1].imshow(cc_mask) + # ax[0].plot(contour[1], contour[0], 'r-') + # ax[1].plot(contour[1], contour[0], 'r-') + # plt.show() + + return contour + + +@overload +def get_endpoints( + cc_mask: Mask2d, + ac_2d: Vector2d, + pc_2d: Vector2d, + resolution: tuple[float, float], + return_coordinates: Literal[True], + contour_smoothing: int = 5 +) -> tuple[np.ndarray, tuple[int, int], tuple[Vector2d, Vector2d]]: ... + + +@overload +def get_endpoints( + cc_mask: Mask2d, + ac_2d: Vector2d, + pc_2d: Vector2d, + resolution: tuple[float, float], + return_coordinates: Literal[False] = False, + contour_smoothing: int = 5 +) -> tuple[np.ndarray, tuple[int, int]]: ... + + +def get_endpoints( + cc_mask: Mask2d, + ac_2d: Vector2d, + pc_2d: Vector2d, + resolution: tuple[float, float], + return_coordinates: bool = False, + contour_smoothing: int = 5 +): + """Determine endpoints of CC by finding points closest to AC and PC. + + Parameters + ---------- + cc_mask : np.ndarray of shape (H, W) and type bool + Binary mask of the corpus callosum. + ac_2d : np.ndarray of shape (2,) and type float + 2D coordinates of the anterior commissure. + pc_2d : np.ndarray of shape (2,) and type float + 2D coordinates of the posterior commissure. + resolution : pair of floats + Inslice image resolution in mm (inferior/superior and anterior/posterior directions). + return_coordinates : bool, default=False + If True, return endpoint coordinates. + contour_smoothing : int, default=5 + Window size for contour smoothing. + + Returns + ------- + contour_rotated : np.ndarray + The contour rotated to AC-PC alignment. + anterior_posterior_point_indices : pair of ints + Indices of anterior and posterior points in the contour. + anterior_posterior_point_coordinates : tuple[np.ndarray, np.ndarray] + Only if return_coordinates is True: Coordinates of anterior and posterior points rotated to AP-PC alignment. + + Notes + ----- + Expects LIA orientation. + """ + image_size = cc_mask.shape + + # Calculate angle between AC-PC line and horizontal using numpy + ac_pc_vector = pc_2d - ac_2d + horizontal_vector = np.array([0, -20]) + # Calculate angle using dot product formula: cos(theta) = (a·b)/(|a||b|) + dot_product = np.dot(ac_pc_vector, horizontal_vector) + norms = np.linalg.norm(ac_pc_vector) * np.linalg.norm(horizontal_vector) + theta = np.arccos(dot_product / norms) + + # Convert symbolic theta to float and convert from radians to degrees + theta_degrees = theta * 180 / np.pi + rotated_cc_mask = scipy.ndimage.rotate(cc_mask, -theta_degrees, order=0, reshape=False) + + contour = extract_cc_contour(rotated_cc_mask, contour_smoothing) + + # rotate points around center + origin_point = np.array([image_size[0] // 2, image_size[1] // 2]) + + # Create rotation matrix for -theta + rot_matrix = np.array([[np.cos(-theta), -np.sin(-theta)], [np.sin(-theta), np.cos(-theta)]]) + + # Translate points to origin, rotate, then translate back + pc_centered = pc_2d - origin_point + ac_centered = ac_2d - origin_point + + rotated_pc_2d = (rot_matrix @ pc_centered) + origin_point + rotated_ac_2d = (rot_matrix @ ac_centered) + origin_point + + # Add z=0 coordinate to make 3D, then remove it after resampling + contour_3d = np.vstack([contour, np.zeros(contour.shape[1])]) + contour_3d = lapy.tria_mesh.TriaMesh._TriaMesh__resample_polygon(contour_3d.T, 701).T + contour = contour_3d[:2] + + contour = contour[:, :-1] + + rotated_ac_2d = np.array(rotated_ac_2d).astype(float) + rotated_pc_2d = np.array(rotated_pc_2d).astype(float) + + # move posterior commisure 5 mm posterior + # FIXME: why is the move 10mm inferior not commented? + # FIXME: multiplication means moving less for smaller voxels, why not division? + # changed to division, 5 mm / voxel size => number of voxels to move + rotated_pc_2d = rotated_pc_2d + np.array([10, -5]) / resolution + + # move anterior commisure 1.5 mm anterior + # FIXME: why does the documentation say 1.5mm when the code says 5mm? + rotated_ac_2d = rotated_ac_2d + np.array([0, 5]) / resolution + + # find point in contour closest to AC + ac_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_ac_2d[:, None], axis=0)) + + # find point in contour closest to PC + pc_startpoint_idx = np.argmin(np.linalg.norm(contour - rotated_pc_2d[:, None], axis=0)) + + # rotate startpoints to original orientation + origin_point = np.array(origin_point).astype(float) + # Create rotation matrix + rot_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + + # Translate points to origin, rotate, then translate back + contour_centered = contour - origin_point[:, None] + contour_rotated = (rot_matrix @ contour_centered) + origin_point[:, None] + + if return_coordinates: + start_point_ac, start_point_pc = contour_rotated[:, [ac_startpoint_idx, pc_startpoint_idx]].T + + return contour_rotated, (ac_startpoint_idx, pc_startpoint_idx), (start_point_ac, start_point_pc) + else: + return contour_rotated, (ac_startpoint_idx, pc_startpoint_idx) diff --git a/CorpusCallosum/shape/mesh.py b/CorpusCallosum/shape/mesh.py new file mode 100644 index 000000000..74fdee1a7 --- /dev/null +++ b/CorpusCallosum/shape/mesh.py @@ -0,0 +1,845 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +from pathlib import Path +from typing import TypeVar + +import lapy +import nibabel as nib +import numpy as np +import plotly.graph_objects as go +from plotly.io import write_html as plotly_write_html +from scipy.ndimage import gaussian_filter1d + +import FastSurferCNN.utils.logging as logging +from CorpusCallosum.data.constants import FSAVERAGE_MIDDLE +from CorpusCallosum.shape.contour import CCContour +from CorpusCallosum.shape.thickness import make_mesh_from_contour +from FastSurferCNN.utils import nibabelImage +from FastSurferCNN.utils.common import suppress_stdout + +try: + from pyrr import Matrix44 + HAS_PYRR = True +except ImportError: + HAS_PYRR = False + class Matrix44(np.ndarray): + pass + +logger = logging.get_logger(__name__) + + + +def _create_cap( + points: np.ndarray, + trias: np.ndarray, + contour: CCContour, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Create a cap mesh for one end of the corpus callosum. + + Parameters + ---------- + points : np.ndarray + Array of shape (N, 2) containing mesh points + trias : np.ndarray + Array of shape (M, 3) containing triangle indices + contour : CCContour + CCContour object to create cap for + + Returns + ------- + tuple[np.ndarray, np.ndarray, np.ndarray] + - level_vertices : Array of vertices for the cap mesh + - level_faces : Array of face indices for the cap mesh + - level_colors : Array of thickness values for each vertex + + Notes + ----- + The function: + 1. Creates level paths using _create_levelpaths + 2. Resamples level paths to fixed number of points + 3. Creates triangles between consecutive level paths + 4. Smooths thickness values for visualization + """ + levelpaths, thickness_values = contour._create_levelpaths(points, trias) + + # Create mesh from level paths + level_vertices = [] + level_faces = [] + level_colors = [] + vertex_counter = 0 + sorted_thickness_values = np.array(thickness_values) + + # smooth thickness values + for _ in range(3): + sorted_thickness_values = gaussian_filter1d(sorted_thickness_values, sigma=5) + + NUM_LEVELPOINTS = 50 + + assert len(sorted_thickness_values) == len(levelpaths) + + # TODO: handle gap between first/last levelpath and contour + for idx, levelpath1 in enumerate(levelpaths): + levelpath1 = lapy.TriaMesh._TriaMesh__iterative_resample_polygon(levelpath1, NUM_LEVELPOINTS) + level_vertices.append(levelpath1) + level_colors.append(np.full((len(levelpath1)), sorted_thickness_values[idx])) + if idx + 1 < len(levelpaths): + levelpath2 = lapy.TriaMesh._TriaMesh__iterative_resample_polygon(levelpaths[idx + 1], NUM_LEVELPOINTS) + + # Create faces between the two paths by connecting vertices + faces_between = [] + i, j = 0, 0 + + while i < len(levelpath1) - 1 and j < len(levelpath2) - 1: + faces_between.append([i, i + 1, len(levelpath1) + j]) + faces_between.append([i + 1, len(levelpath1) + j + 1, len(levelpath1) + j]) + + i += 1 + j += 1 + + while i < len(levelpath1) - 1: + faces_between.append([i, i + 1, len(levelpath1) + j]) + i += 1 + + while j < len(levelpath2) - 1: + faces_between.append([i, len(levelpath1) + j + 1, len(levelpath1) + j]) + j += 1 + + if faces_between: + faces_between = np.array(faces_between) + level_faces.append(faces_between + vertex_counter) + + vertex_counter += len(levelpath1) + + # Convert to numpy arrays + level_vertices = np.vstack(level_vertices) + level_faces = np.vstack(level_faces) + level_colors = np.concatenate(level_colors) + + return level_vertices, level_faces, level_colors + + +def make_triangles_between_contours(contour1: np.ndarray, contour2: np.ndarray) -> np.ndarray: + """Create a triangular mesh between two contours using a robust method. + + Parameters + ---------- + contour1 : np.ndarray + First contour points of shape (N, 2). + contour2 : np.ndarray + Second contour points of shape (M, 2). + + Returns + ------- + np.ndarray + Array of triangle indices of shape (K, 3) where K is the number of triangles. + + Notes + ----- + The function: + 1. Finds closest point on contour2 to first point of contour1 + 2. Creates triangles by connecting corresponding points + 3. Handles contours with different numbers of points + 4. Creates two triangles to form a quad between each pair of points + """ + start_idx_c1 = 0 + # get closest point on contour2 to contour1[0] + start_idx_c2 = np.argmin(np.linalg.norm(contour2 - contour1[0], axis=1)) + + triangles = [] + n1 = len(contour1) + n2 = len(contour2) + + for i in range(n1): + # Current and next indices for contour1 + c1_curr = (start_idx_c1 + i) % n1 + c1_next = (start_idx_c1 + i + 1) % n1 + + # Current and next indices for contour2, offset by n1 to account for vertex stacking + c2_curr = ((start_idx_c2 + i) % n2) + n1 + c2_next = ((start_idx_c2 + i + 1) % n2) + n1 + + # Create two triangles to form a quad between the contours + triangles.append([c1_curr, c2_curr, c1_next]) + triangles.append([c2_curr, c2_next, c1_next]) + + return np.array(triangles) + + +Self = TypeVar('Self', bound='type[CCMesh]') + + +class CCMesh(lapy.TriaMesh): + """A class for representing and manipulating corpus callosum (CC) meshes. + + This class extends lapy.TriaMesh to provide specialized functionality for working with + corpus callosum meshes, including contour management, thickness measurements, and + visualization capabilities. + + The mesh can be constructed from a series of 2D contours representing slices of the + corpus callosum, with optional thickness measurements at various points along these + contours. + + Attributes + ---------- + v : np.ndarray + Vertex coordinates of the mesh. + t : np.ndarray + Triangle indices of the mesh. + mesh_vertex_colors : np.ndarray + Vertex values for each vertex (CC thickness values) + resolution : float + Spatial resolution of the mesh in millimeters. + """ + + def __init__(self, + vertices: list | np.ndarray, + faces: list | np.ndarray, + vertex_values: list | np.ndarray | None = None, + resolution: float = 1.0): + """Initialize a CC_Mesh object. + + Parameters + ---------- + vertices : list or numpy.ndarray + List of vertex coordinates or array of shape (N, 3). + faces : list or numpy.ndarray + List of face indices or array of shape (M, 3). + vertex_values : list or numpy.ndarray, optional + Vertex values for each vertex (CC thickness values) + resolution : float, optional + Spatial resolution of the mesh in millimeters, by default 1.0. + """ + super().__init__(np.vstack(vertices), np.vstack(faces)) + self.mesh_vertex_colors = vertex_values + self.resolution = resolution + + def plot_mesh( + self, + output_path: Path | str | None = None, + colormap: str = "red_to_yellow", + thickness_overlay: bool = True, + show_grid: bool = False, + color_range: tuple[float, float] | None = None, + legend: str = "", + threshold: tuple[float, float] | None = None, + ): + """Plot the mesh using Plotly for better performance and interactivity. + + Creates an interactive 3D visualization of the mesh with optional features like + thickness overlay, contour display, and grid visualization. + + Parameters + ---------- + output_path : Path, str, optional + Path to save the plot. If None, displays the plot interactively. + colormap : str, optional + Which colormap to use, by default "red_to_yellow". + Options: + - "red_to_blue": Red -> Orange -> Grey -> Light Blue -> Blue + - "red_to_yellow": Red -> Yellow -> Light Blue -> Blue + - "yellow_to_red": Yellow -> Light Blue -> Blue -> Red + - "blue_to_red": Blue -> Light Blue -> Grey -> Orange -> Red + thickness_overlay : bool, optional + Whether to overlay thickness values on the mesh, by default True. + show_contours : bool, optional + Whether to show the contours, by default False. + show_grid : bool, optional + Whether to show the grid, by default False. + color_range : tuple[float, float], optional + Fixed range (min, max) for the colorbar, by default None. + show_mesh_edges : bool, optional + Whether to show the mesh edges, by default False. + legend : str, optional + Legend text for the colorbar, by default "". + threshold : tuple[float, float], optional + Values between these thresholds will be shown in grey, by default None. + + Notes + ----- + The plot can be saved to an HTML file or displayed in a web browser. + """ + assert self.v is not None and self.t is not None, "Mesh has not been created yet" + + if len(self.v) == 0: + logger.warning("Warning: No vertices in mesh to plot") + return + + if len(self.t) == 0: + logger.warning("Warning: No faces in mesh to plot") + return + + # Define available colormaps + colormaps = { + "red_to_blue": [ + [0.0, "rgb(255,0,0)"], # Bright red + [0.25, "rgb(255,165,0)"], # Light orange + [0.5, "rgb(150,150,150)"], # Dark grey for middle + [0.75, "rgb(173,216,230)"], # Light blue + [1.0, "rgb(0,0,255)"], # Bright blue + ], + "blue_to_red": [ + [0.0, "rgb(0,0,255)"], # Bright blue + [0.25, "rgb(173,216,230)"], # Light blue + [0.5, "rgb(150,150,150)"], # Dark grey for middle + [0.75, "rgb(255,165,0)"], # Light orange + [1.0, "rgb(255,0,0)"], # Bright red + ], + "red_to_yellow": [ + [0.0, "rgb(255,0,0)"], # Bright red + [0.33, "rgb(255,85,0)"], # Red-orange + [0.66, "rgb(255,170,0)"], # Orange + [1.0, "rgb(255,255,0)"], # Yellow + ], + "yellow_to_red": [ + [0.0, "rgb(255,255,0)"], # Yellow + [0.33, "rgb(255,170,0)"], # Orange + [0.66, "rgb(255,85,0)"], # Red-orange + [1.0, "rgb(255,0,0)"], # Bright red + ], + } + + # Select the colormap + if colormap not in colormaps: + logger.warning(f"Warning: Unknown colormap '{colormap}'. Using 'red_to_blue' instead.") + colormap = "red_to_blue" + + selected_colormap = colormaps[colormap] + + # If threshold is provided, modify the colormap to include grey region + if threshold is not None and thickness_overlay and hasattr(self, "mesh_vertex_colors"): + data_min = np.min(self.mesh_vertex_colors) if color_range is None else color_range[0] + data_max = np.max(self.mesh_vertex_colors) if color_range is None else color_range[1] + data_range = data_max - data_min + + # Calculate normalized threshold positions + thresh_low = (threshold[0] - data_min) / data_range + thresh_high = (threshold[1] - data_min) / data_range + + # Ensure thresholds are within [0,1] + thresh_low = max(0, min(1, thresh_low)) + thresh_high = max(0, min(1, thresh_high)) + + # Create new colormap with grey threshold region + grey_color = "rgb(150,150,150)" # Medium grey + new_colormap = [] + + # Add colors before threshold with adjusted positions + if thresh_low > 0: + for pos, color in selected_colormap: + if pos < 1: # Only use positions less than 1 + new_pos = pos * thresh_low + new_colormap.append([new_pos, color]) + + # Add threshold boundaries with grey + new_colormap.extend([[thresh_low, grey_color], [thresh_high, grey_color]]) + + # Add colors after threshold with adjusted positions + if thresh_high < 1: + remaining_range = 1 - thresh_high + for pos, color in selected_colormap: + if pos > 0: # Only use positions greater than 0 + new_pos = thresh_high + pos * remaining_range + if new_pos <= 1: # Ensure we don't exceed 1 + new_colormap.append([new_pos, color]) + + selected_colormap = new_colormap + + # Calculate data ranges and center + xyz_min = self.v.min(axis=0) + xyz_max = self.v.max(axis=0) + xyz_range = xyz_max - xyz_min + max_range = xyz_range.max() + center = (xyz_max + xyz_min) / 2 + + # Create mesh plot + fig = go.Figure() + + # Add the mesh as a surface + mesh_args = { + "x": self.v[:, 0], + "y": self.v[:, 1], + "z": self.v[:, 2], + "i": self.t[:, 0], # First vertex of each triangle + "j": self.t[:, 1], # Second vertex + "k": self.t[:, 2], # Third vertex + "hoverinfo": "skip", + "lighting": dict(ambient=0.9, diffuse=0.1, roughness=0.3), + } + + if thickness_overlay and hasattr(self, "mesh_vertex_colors"): + mesh_args.update( + { + "intensity": self.mesh_vertex_colors, # Add intensity values for colorbar + "showscale": True, + "colorbar": dict( + title=dict( + text=legend, + font=dict(size=35, color="white"), # Increase title font size and make white + side="right", # Place title on right side + ), + len=0.55, # Make colorbar shorter + thickness=35, # Make colorbar wider + tickfont=dict(size=30, color="white"), # Increase tick font size and make white + tickformat=".1f", # Show one decimal place + ), + "opacity": 1, + "colorscale": selected_colormap, + } + ) + + # Set the colorbar range + if color_range is not None: + mesh_args["cmin"] = color_range[0] + mesh_args["cmax"] = color_range[1] + else: + # Use data range if no explicit range is provided + mesh_args["cmin"] = np.min(self.mesh_vertex_colors) + mesh_args["cmax"] = np.max(self.mesh_vertex_colors) + else: + mesh_args["color"] = "lightsteelblue" + + fig.add_trace(go.Mesh3d(**mesh_args)) + + # Calculate axis ranges to maintain equal aspect ratio + ranges = [] + for i in range(3): + axis_range = [center[i] - max_range / 2, center[i] + max_range / 2] + ranges.append(axis_range) + + # Configure axes and grid visibility + axis_config = dict( + showgrid=show_grid, + showline=show_grid, + zeroline=show_grid, + showbackground=show_grid, + showticklabels=show_grid, + gridcolor="white", + tickfont=dict(color="white"), + title=dict(font=dict(color="white")), + ) + + fig.update_layout( + scene=dict( + xaxis=dict(range=ranges[0], **{**axis_config, "title": "AP" if show_grid else ""}), + yaxis=dict(range=ranges[1], **{**axis_config, "title": "SI" if show_grid else ""}), + zaxis=dict(range=ranges[2], **{**axis_config, "title": "LR" if show_grid else ""}), + camera=dict(eye=dict(x=1.5, y=1.5, z=1), up=dict(x=0, y=0, z=1)), + aspectmode="cube", # Force equal aspect ratio + aspectratio=dict(x=1, y=1, z=1), + bgcolor="black", + dragmode="orbit", # Enable orbital rotation by default + ), + showlegend=False, + margin=dict(l=0, r=100, t=0, b=0), # Increased right margin for colorbar + paper_bgcolor="black", + plot_bgcolor="black", + ) + + if output_path is not None: + self.__make_parent_folder(output_path) + plotly_write_html(fig, output_path, include_plotlyjs="cdn") # Save as interactive HTML + else: + # For non-interactive display, save to a temporary HTML and open in browser + import tempfile + import webbrowser + + temp_path = Path(tempfile.gettempdir()) / "cc_mesh_plot.html" + plotly_write_html(fig, temp_path, include_plotlyjs="cdn") + webbrowser.open(f"file://{temp_path}") + + + @staticmethod + def __create_cc_viewmat() -> "Matrix44": + """Create the view matrix for a nice view of the corpus callosum. + + Returns + ------- + Matrix44 + 4x4 view matrix that provides a standard view of the corpus callosum (from pyrr). + + Notes + ----- + The function: + 1. Creates a base view matrix looking from the left with top up + 2. Applies a series of rotations: + - -10 degrees around x-axis + - 35 degrees around y-axis + - -8 degrees around z-axis + 3. Adds a small translation for better centering + """ + + if not HAS_PYRR: + raise ImportError("Pyrr not installed, install pyrr with `pip install pyrr`.") + + viewLeft = np.array([[0, 0, -1, 0], [-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) # left w top up // right + transl = Matrix44.from_translation((0, 0, 0.4)) + viewmat = transl * viewLeft + + # rotate 10 degrees around x axis + rot = Matrix44.from_x_rotation(np.deg2rad(-10)) + viewmat = viewmat * rot + + # rotate 35 degrees around y axis + rot = Matrix44.from_y_rotation(np.deg2rad(35)) + viewmat = viewmat * rot + + # rotate 10 degrees around z axis + rot = Matrix44.from_z_rotation(np.deg2rad(-8)) + viewmat = viewmat * rot + + return viewmat + + def snap_cc_picture( + self, + output_path: Path | str, + fssurf_file: Path | str | None = None, + overlay_file: Path | str | None = None, + ref_image: Path | str | nibabelImage | None = None, + ) -> None: + """Snap a picture of the corpus callosum mesh. + + Parameters + ---------- + output_path : Path, str + Path where to save the snapshot image. + fssurf_file : Path, str, optional + Path to a FreeSurfer surface file to use for the snapshot. + If None, the mesh is saved to a temporary file. + overlay_file : Path, str, optional + Path to a FreeSurfer overlay file to use for the snapshot. + If None, the mesh is saved to a temporary file. + ref_image : Path, str, optional + Path to reference image to use for tkr creation. If None, ignores the file for saving. + + Raises + ------ + Warning + If the mesh has no faces and cannot create a snapshot. + + Notes + ----- + The function: + 1. Creates temporary files for mesh and overlay data if needed. + 2. Uses whippersnappy to create a snapshot with: + - Custom view matrix for standard orientation. + - Ambient lighting and colorbar settings. + - Thickness overlay if available. + 3. Cleans up temporary files after use. + """ + try: + from whippersnappy.core import snap1 + except ImportError: + # whippersnappy not installed + raise RuntimeError( + "The snap_cc_picture method of CCMesh requires whippersnappy, but whippersnappy was not found. " + "Please install whippersnappy!" + ) from None + self.__make_parent_folder(output_path) + # Skip snapshot if there are no faces + if len(self.t) == 0: + logger.warning("Cannot create snapshot - no faces in mesh") + return + + # create temp file + if fssurf_file: + fssurf_file = Path(fssurf_file) + else: + fssurf_file = tempfile.NamedTemporaryFile(suffix=".fssurf", delete=True).name + self.write_fssurf(fssurf_file, image=str(ref_image) if isinstance(ref_image, Path) else ref_image) + + if overlay_file: + overlay_file = Path(overlay_file) + else: + overlay_file = tempfile.NamedTemporaryFile(suffix=".w", delete=True).name + # Write thickness values in FreeSurfer '*.w' overlay format + self.write_morph_data(overlay_file) + + try: + with suppress_stdout(): + snap1( + fssurf_file, + overlaypath=overlay_file, + view=None, + viewmat=self.__create_cc_viewmat(), + width=3 * 500, + height=3 * 300, + outpath=output_path, + ambient=0.6, + colorbar_scale=0.5, + colorbar_y=0.88, + colorbar_x=0.19, + brain_scale=2.1, + fthresh=0, + caption="Corpus Callosum thickness (mm)", + caption_y=0.85, + caption_x=0.17, + caption_scale=0.5, + ) + except Exception as e: + raise e from None + + if fssurf_file and hasattr(fssurf_file, "close"): + fssurf_file.close() + if overlay_file and hasattr(overlay_file, "close"): + overlay_file.close() + + def smooth_(self, iterations: int = 1) -> None: + """Smooth the mesh while preserving the z-coordinates. + + Parameters + ---------- + iterations : int, optional + Number of smoothing iterations, by default 1. + + Notes + ----- + The function: + 1. Stores original z-coordinates. + 2. Applies Laplacian smoothing to x and y coordinates. + 3. Restores original z-coordinates to maintain slice structure. + """ + z_values = self.v[:, 2] + super().smooth_(iterations) + self.v[:, 2] = z_values + + + @staticmethod + def __make_parent_folder(filename: Path | str) -> None: + """Create the parent folder for a file if it doesn't exist. + + Parameters + ---------- + filename : Path, str + Path to the file whose parent folder should be created. + + Notes + ----- + Creates parent directory with parents=False to avoid creating + multiple levels of directories unintentionally. + """ + Path(filename).parent.mkdir(parents=False, exist_ok=True) + + def to_fs_coordinates( + self, + lr_offset: float, + ) -> "CCMesh": + """Convert mesh coordinates to FreeSurfer coordinate system. + + Parameters + ---------- + lr_offset : float + Voxel offset to apply before transformation, this should be often `FSAVERAGE_MIDDLE / vox_size_in_lr`. + + Returns + ------- + CCMesh + A CCMesh object with vertices reoriented to FreeSurfer coordinates. + + Notes + ----- + Mesh coordinates are in ASR (Anterior-Superior-Right) orientation, with the coordinate system origin on + *the* midslice. The function transforms from midslice ASR to LIA vox coordinates. + """ + from copy import copy + new_object = copy(self) + + asrvox_midslice2orig_vox2vox = np.eye(4) + # to LSA + asrvox_midslice2orig_vox2vox[:, [0, 2]] = asrvox_midslice2orig_vox2vox[:, [2, 0]] + # center LR + asrvox_midslice2orig_vox2vox[0, 3] = lr_offset + # flip SI + asrvox_midslice2orig_vox2vox[:, 1] *= -1 + + # to LSA + # new_object.v = new_object.v[:, [2, 1, 0]] + # to voxel + # FIXME: why are the vertex positions multiplied by voxel size here? + # removed => for center LR, now dividing by resolution => convert fsaverage middle from mm to vox + # => remove the conversion back to mm in the end + # all other operations are independent of order of operations (distributive) + # v_vox /= vox_size[0] + # center LR + # new_object.v[:, 0] += FSAVERAGE_MIDDLE / self.resolution + # flip SI + # new_object.v[:, 1] = -new_object.v[:, 1] + + #v_vox_test = np.round(v_vox).astype(int) + + # tkrRAS = Torig*[C R S 1]' + # Torig: mri_info --vox2ras-tkr orig.mgz + # https://surfer.nmr.mgh.harvard.edu/fswiki/CoordinateSystems + + v_vox = np.concatenate([self.v, np.ones((self.v.shape[0], 1))], axis=1) + new_object.v = (v_vox @ asrvox_midslice2orig_vox2vox.T)[:, :3] + # new_object.v = (vox2ras_tkr @ np.concatenate([self.v, np.ones((self.v.shape[0], 1))], axis=1).T).T[:, :3] + return new_object + + def write_fssurf(self, filename: Path | str, image: str | object | None = None) -> None: + """Save as Freesurfer Surface Geometry file (wrap Nibabel). + + Parameters + ---------- + filename : str + Filename to save to. + image : str, object, None + Path to image or nibabel image object. If specified, the vertices + are assumed to be in voxel coordinates and are converted + to surface RAS (tkr) coordinates before saving. + The expected order of coordinates is (x, y, z) matching + the image voxel indices. + + Notes + ----- + Also creates parent directory if needed before writing the file. + """ + self.__make_parent_folder(filename) + return super().write_fssurf(filename, image=image) + + def write_morph_data(self, filename: Path | str) -> None: + """Write the thickness values as a FreeSurfer overlay file. + + Parameters + ---------- + filename : Path, str + Path where to save the overlay file. + + Notes + ----- + Creates parent directory if needed before writing the file. + """ + self.__make_parent_folder(filename) + return nib.freesurfer.write_morph_data(filename, self.mesh_vertex_colors) + + @classmethod + def from_contours( + cls: Self, + contours: list[CCContour], + lr_center: float = 0, + closed: bool = False, + smooth: int = 0, + ) -> Self: + """Create a surface mesh by triangulating between consecutive contours. + + Parameters + ---------- + contours : list[CCContour] + List of CCContour objects to create mesh from. + lr_center : float, default=0 + Center position in the left-right axis. + closed : bool, default=False + Whether to create a closed mesh by adding caps. + smooth : int, default=0 + Number of smoothing iterations to apply. + + Returns + ------- + CCMesh + The joined CCMesh object. + + Raises + ------ + Warning + If no valid contours are found. + + Notes + ----- + The function: + 1. Filters out None contours. + 2. Calculates z-coordinates for each slice. + 3. Creates triangles between adjacent contours. + 4. Optionally: + - Creates caps at both ends. + - Applies smoothing. + - Colors caps based on thickness values. + + """ + + # Check that all contours have the same resolution + resolution = contours[0].resolution + for idx, contour in enumerate(contours[1:], start=1): + if not np.isclose(contour.resolution, resolution): + raise ValueError( + f"All contours must have the same resolution. " + f"Expected {resolution}, but contour at index {idx} has {contour.resolution}." + ) + + # Calculate z coordinates for each slice + z_coordinates = (np.arange(len(contours)) - len(contours) // 2) * contours[0].resolution + lr_center + + # Build vertices list with z-coordinates + vertices = [] + faces = [] + vertex_start_indices = [] # Track starting index for each contour + current_index = 0 + + for i, contour in enumerate(contours): + vertex_start_indices.append(current_index) + vertices.append(np.hstack([contour.contour, np.full((len(contour.contour), 1), z_coordinates[i])])) + + # Check if there's a next valid contour to connect to + if i + 1 < len(contours): + contour2 = contours[i + 1] + faces_between = make_triangles_between_contours(contour.contour, contour2.contour) + faces.append(faces_between + current_index) + + current_index += len(contour.contour) + + vertex_values = np.concatenate([contour.thickness_values for contour in contours]) + + if smooth > 0: + tmp_mesh = CCMesh(vertices, faces, vertex_values=vertex_values) + tmp_mesh.smooth_(smooth) + vertices = tmp_mesh.v + faces = tmp_mesh.t + vertex_values = tmp_mesh.mesh_vertex_colors + + if closed: + # Close the mesh by creating caps on both ends + # Left cap (first slice) - use counterclockwise orientation + left_side_points, left_side_trias = make_mesh_from_contour(vertices[: vertex_start_indices[1]][..., :2]) + left_side_points = np.hstack([left_side_points, np.full((len(left_side_points), 1), z_coordinates[0])]) + + # Right cap (last slice) - reverse points for proper orientation + right_side_points, right_side_trias = make_mesh_from_contour(vertices[vertex_start_indices[-1]:][..., :2]) + right_side_points = np.hstack([right_side_points, np.full((len(right_side_points), 1), z_coordinates[-1])]) + + #FIXME: Can we remove this if-statement? + color_sides = True + if color_sides: + left_side_points, left_side_trias, left_side_colors = _create_cap( + left_side_points, left_side_trias, contours[0] + ) + right_side_points, right_side_trias, right_side_colors = _create_cap( + right_side_points, right_side_trias, contours[-1] + ) + # reverse right side trias + right_side_trias = right_side_trias[:, ::-1] + else: + left_side_colors, right_side_colors = [], [] + + left_side_trias = left_side_trias + current_index + current_index += len(left_side_points) + + right_side_trias = right_side_trias + current_index + current_index += len(right_side_points) + + # FIXME: should this not be a concatenate statements? + vertices = [vertices, left_side_points, right_side_points] + faces = [faces, left_side_trias, right_side_trias] + vertex_values = [vertex_values, left_side_colors, right_side_colors] + + return cls(vertices, faces, vertex_values=vertex_values, resolution=resolution) diff --git a/CorpusCallosum/shape/metrics.py b/CorpusCallosum/shape/metrics.py new file mode 100644 index 000000000..921c819f4 --- /dev/null +++ b/CorpusCallosum/shape/metrics.py @@ -0,0 +1,328 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import FastSurferCNN.utils.logging as logging + +logger = logging.get_logger(__name__) + + +# TODO: we could make this more robust by standardizing orientation with AC/PC and smoothing the contour + +def _line_segment_intersection( + line_point: np.ndarray, + line_dir: np.ndarray, + seg_start: np.ndarray, + seg_end: np.ndarray, + tol: float = 1e-10, +) -> np.ndarray | None: + """Compute intersection between an infinite line and a line segment. + + Uses the parametric form: + - Line: P = line_point + t * line_dir + - Segment: Q = seg_start + s * (seg_end - seg_start), where s ∈ [0, 1] + + Parameters + ---------- + line_point : np.ndarray + A point on the infinite line, shape (2,). + line_dir : np.ndarray + Direction vector of the line, shape (2,). + seg_start : np.ndarray + Start point of the segment, shape (2,). + seg_end : np.ndarray + End point of the segment, shape (2,). + tol : float + Tolerance for numerical comparisons. + + Returns + ------- + np.ndarray | None + Intersection point as shape (2,) array, or None if no intersection. + """ + seg_dir = seg_end - seg_start + + # Build the linear system: [line_dir, -seg_dir] @ [t, s].T = seg_start - line_point + # Matrix A = [[line_dir[0], -seg_dir[0]], [line_dir[1], -seg_dir[1]]] + A = np.array([[line_dir[0], -seg_dir[0]], + [line_dir[1], -seg_dir[1]]]) + b = seg_start - line_point + + # Check if lines are parallel (determinant ≈ 0) + det = A[0, 0] * A[1, 1] - A[0, 1] * A[1, 0] + if abs(det) < tol: + return None + + # Solve for t and s using Cramer's rule (faster than linalg.solve for 2x2) + t = (b[0] * A[1, 1] - b[1] * A[0, 1]) / det + s = (A[0, 0] * b[1] - A[1, 0] * b[0]) / det + + # Check if intersection is within the segment [0, 1] + if -tol <= s <= 1.0 + tol: + return line_point + t * line_dir + return None + + +def get_intersections( + contour: np.ndarray, start_point: np.ndarray, direction: np.ndarray +) -> np.ndarray: + """Find intersection points between an infinite line and a closed contour. + + Parameters + ---------- + contour : np.ndarray + Array of shape (2, N) containing contour points in ACPC space. + start_point : np.ndarray + A point on the line, shape (2,). + direction : np.ndarray + Direction vector of the line, shape (2,). + + Returns + ------- + np.ndarray + Array of shape (M, 2) containing intersection points, sorted along the direction. + """ + start_point = np.asarray(start_point, dtype=float) + direction = np.asarray(direction, dtype=float) + + # Normalize direction + dir_norm = np.linalg.norm(direction) + if dir_norm < 1e-10: + return np.empty((0, 2)) + direction = direction / dir_norm + + n_points = contour.shape[1] + intersections = [] + + # Check intersection with each segment of the closed contour + for i in range(n_points): + seg_start = contour[:, i] + seg_end = contour[:, (i + 1) % n_points] # Wrap around to close the contour + + intersection = _line_segment_intersection( + start_point, direction, seg_start, seg_end + ) + if intersection is not None: + intersections.append(intersection) + + if not intersections: + return np.empty((0, 2)) + + points = np.array(intersections) + + # Remove duplicate points (can occur at contour vertices) + if len(points) > 1: + # Project onto line direction and find unique points + projections = np.dot(points - start_point, direction) + # Sort and remove duplicates within tolerance + sorted_idx = np.argsort(projections) + points = points[sorted_idx] + projections = projections[sorted_idx] + + # Keep points that are sufficiently far apart + mask = np.ones(len(points), dtype=bool) + for i in range(1, len(points)): + if abs(projections[i] - projections[i - 1]) < 1e-8: + mask[i] = False + points = points[mask] + + return points + + +def calculate_cc_index(cc_contour: np.ndarray, plot: bool = False) -> float: + """Calculate CC index based on three thickness measurements. + + The AP line intersects the contour 4 times. The measurements are: + - Anterior thickness: distance between intersection points 1 and 2 + - Posterior thickness: distance between intersection points 3 and 4 + - Middle thickness: perpendicular line through midpoint of AP line + + The CC index is: (anterior + posterior + middle) / AP_length + + Parameters + ---------- + cc_contour : np.ndarray + Array of shape (2, N) containing contour points in ACPC space. + plot : bool, optional + Whether to generate a debug plot. Default is True. + + Returns + ------- + cc_index : float + The CC index, which is the sum of thicknesses at three measurement points divided by AP length. + """ + # Get anterior and posterior points (extremes along x-axis) + anterior_idx = np.argmin(cc_contour[0]) # Leftmost point + posterior_idx = np.argmax(cc_contour[0]) # Rightmost point + + anterior_pt = cc_contour[:, anterior_idx] + posterior_pt = cc_contour[:, posterior_idx] + + # AP line vector and properties + ap_vector = posterior_pt - anterior_pt + ap_length = np.linalg.norm(ap_vector) + ap_unit = ap_vector / ap_length + + # Perpendicular direction (90 degrees rotated) + perp_unit = np.array([-ap_unit[1], ap_unit[0]]) + + # Find where AP line intersects the contour (should be 4 points) + ap_intersections = get_intersections( + contour=cc_contour, start_point=anterior_pt, direction=ap_unit + ) + + if len(ap_intersections) != 4: + logger.error( + f"AP line should intersect contour exactly 4 times, " + f"but found {len(ap_intersections)} intersections" + ) + return 0.0 + + # Measurement 1: anterior thickness (between intersection points 1 and 2) + anterior_thickness = np.linalg.norm(ap_intersections[0] - ap_intersections[1]) + + # Measurement 2: posterior thickness (between intersection points 3 and 4) + posterior_thickness = np.linalg.norm(ap_intersections[2] - ap_intersections[3]) + + # AP distance is between outermost intersection points (1 and 4) + ap_distance = np.linalg.norm(ap_intersections[0] - ap_intersections[3]) + + # Midpoint of AP line (between points 1 and 4, or between anterior and posterior extremes) + midpoint = (ap_intersections[0] + ap_intersections[3]) / 2 + + # Measurement 3: perpendicular line through midpoint + middle_intersections = get_intersections( + contour=cc_contour, start_point=midpoint, direction=perp_unit + ) + + middle_thickness = np.linalg.norm(middle_intersections[0] - middle_intersections[-1]) + + # Calculate CC index + cc_index = (anterior_thickness + posterior_thickness + middle_thickness) / ap_distance + + if plot: + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(figsize=(8, 6)) + plot_cc_index_calculation( + ax, + cc_contour, + anterior_idx, + posterior_idx, + ap_intersections, + middle_intersections, + midpoint, + ) + ax.legend() + plt.show() + + return cc_index + + +def plot_cc_index_calculation( + ax, + cc_contour: np.ndarray, + anterior_idx: int, + posterior_idx: int, + ap_intersections: np.ndarray, + middle_intersections: np.ndarray, + midpoint: np.ndarray, +) -> None: + """Plot the CC index measurements. + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axes to plot on. + cc_contour : np.ndarray + Array of shape (2, N) containing contour points in ACPC space. + anterior_idx : int + Index of the anterior point on the contour. + posterior_idx : int + Index of the posterior point on the contour. + ap_intersections : np.ndarray + Array of shape (4, 2) containing the 4 intersection points of the AP line with the contour. + middle_intersections : np.ndarray + Array of shape (2, 2) containing middle perpendicular intersection points. + midpoint : np.ndarray + Array of shape (2,) containing the midpoint of the AP line. + """ + from matplotlib.patches import PathPatch + from matplotlib.path import Path + + # Plot the CC contour (closed) + ax.plot(cc_contour[0], cc_contour[1], "k-", linewidth=1) + ax.plot( + [cc_contour[0, -1], cc_contour[0, 0]], + [cc_contour[1, -1], cc_contour[1, 0]], + "k-", + linewidth=1, + ) + + # Plot AP line through all 4 intersection points + ax.plot( + [ap_intersections[0, 0], ap_intersections[3, 0]], + [ap_intersections[0, 1], ap_intersections[3, 1]], + "r--", + linewidth=1, + label="AP line", + ) + + # Mark all 4 intersection points + for i, pt in enumerate(ap_intersections): + ax.scatter([pt[0]], [pt[1]], s=40, zorder=5) + ax.annotate(f"{i+1}", (pt[0], pt[1]), textcoords="offset points", + xytext=(5, 5), fontsize=10) + + # Plot anterior thickness (points 1-2) + ax.plot( + [ap_intersections[0, 0], ap_intersections[1, 0]], + [ap_intersections[0, 1], ap_intersections[1, 1]], + "b-", + linewidth=3, + label="Anterior thickness (1-2)", + ) + + # Plot posterior thickness (points 3-4) + ax.plot( + [ap_intersections[2, 0], ap_intersections[3, 0]], + [ap_intersections[2, 1], ap_intersections[3, 1]], + "c-", + linewidth=3, + label="Posterior thickness (3-4)", + ) + + # Plot middle thickness measurement (perpendicular) + ax.plot( + [middle_intersections[0, 0], middle_intersections[-1, 0]], + [middle_intersections[0, 1], middle_intersections[-1, 1]], + "g-", + linewidth=3, + label="Middle thickness", + ) + + # Mark midpoint + ax.scatter([midpoint[0]], [midpoint[1]], color="red", s=50, zorder=5, + marker="x", label="Midpoint") + + ax.set_aspect("equal") + + # Fill the contour with gray + contour_path = Path(cc_contour.T) + patch = PathPatch(contour_path, facecolor="gray", alpha=0.2, edgecolor=None) + ax.add_patch(patch) + + ax.invert_xaxis() + ax.axis("off") diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py new file mode 100644 index 000000000..e335ccf8b --- /dev/null +++ b/CorpusCallosum/shape/postprocessing.py @@ -0,0 +1,610 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import concurrent.futures +from copy import copy +from functools import partial +from pathlib import Path +from typing import get_args + +import numpy as np + +import FastSurferCNN.utils.logging as logging +from CorpusCallosum.data.constants import CC_LABEL, FSAVERAGE_MIDDLE, SUBSEGMENT_LABELS +from CorpusCallosum.shape.contour import CCContour +from CorpusCallosum.shape.endpoint_heuristic import get_endpoints +from CorpusCallosum.shape.mesh import CCMesh +from CorpusCallosum.shape.metrics import calculate_cc_index +from CorpusCallosum.shape.subsegment_contour import ( + ContourList, + get_primary_eigenvector, + hampel_subdivide_contour, + subdivide_contour, + subsegment_midline_orthogonal, + transform_to_acpc_standard, +) +from CorpusCallosum.shape.thickness import cc_thickness, convert_to_ras +from CorpusCallosum.utils.types import CCMeasuresDict, ContourThickness, Points2dType, SliceSelection, SubdivisionMethod +from CorpusCallosum.utils.visualization import plot_contours +from FastSurferCNN.utils import AffineMatrix4x4, Image3d, ScalarType, Shape2d, Shape3d, Vector2d, Mask2d +from FastSurferCNN.utils.common import SubjectDirectory, suppress_stdout, update_docstring +from FastSurferCNN.utils.parallel import process_executor, thread_executor + +logger = logging.get_logger(__name__) + +# assert LIA orientation +LIA_ORIENTATION = np.zeros((3,3)) +LIA_ORIENTATION[0,0] = -1 +LIA_ORIENTATION[1,2] = 1 +LIA_ORIENTATION[2,1] = -1 + + +def create_sag_slice_vox2vox(slice_idx: int, fsaverage_middle: float) -> AffineMatrix4x4: + """Create slice-specific slice to full affine transformation matrix. + + Returns a volume to slice in volume affine. + + Parameters + ---------- + slice_idx : int + Index of the slice to transform. + fsaverage_middle : float + Reference middle slice index in fsaverage space. + + Returns + ------- + np.ndarray + Modified 4x4 affine transformation matrix for the specific slice. + """ + slice2full_vox2vox: AffineMatrix4x4 = np.eye(4, dtype=float) + slice2full_vox2vox[0, 3] = -fsaverage_middle + slice_idx + return slice2full_vox2vox + + +@update_docstring(SubdivisionMethod=str(get_args(SubdivisionMethod))[1:-1]) +def recon_cc_surf_measures_multi( + segmentation: np.ndarray[Shape3d, np.dtype[np.int_]], + slice_selection: SliceSelection, + fsavg_vox2ras: AffineMatrix4x4, + midslices: Image3d, + ac_coords: Vector2d, + pc_coords: Vector2d, + num_thickness_points: int, + subdivisions: list[float], + subdivision_method: SubdivisionMethod, + contour_smoothing: int, + subject_dir: SubjectDirectory, + vox_size: tuple[float, float, float], +) -> tuple[list[CCMeasuresDict], list[concurrent.futures.Future]]: + """Surface reconstruction and metrics computation of corpus callosum slices based on selection mode. + + Parameters + ---------- + segmentation : np.ndarray + 3D segmentation array. + slice_selection : str + Which slices to process ('middle', 'all', or slice number). + fsavg_vox2ras : np.ndarray + Base affine transformation matrix (fsaverage, upright space). + midslices : np.ndarray + Array of mid-sagittal slices. + ac_coords : np.ndarray + Anterior commissure coordinates. + pc_coords : np.ndarray + Posterior commissure coordinates. + num_thickness_points : int + Number of points for thickness estimation. + subdivisions : list[float] + List of fractions for anatomical subdivisions. + subdivision_method : {SubdivisionMethod} + Method for contour subdivision. + contour_smoothing : int + Gaussian sigma for contour smoothing. + subject_dir : SubjectDirectory + The SubjectDirectory object managing file names in the subject directory. + vox_size : 3-tuple of floats + LIA-oriented voxel size in millimeters (x, y, z). + + Returns + ------- + list of CCMeasuresDict + List of slice processing results. + list of concurrent.futures.Future + List of background IO processes. + """ + slice_cc_measures: list[CCMeasuresDict] = [] + io_futures = [] + + if subdivision_method == "angular" and not np.allclose(np.diff(subdivisions), np.diff(subdivisions)[0]): + raise ValueError( + f"Angular subdivision method (Hampel) only supports equidistant subdivision, " + f"but got: {subdivisions}. No measures are computed.", + ) + + _each_slice = partial( + recon_cc_surf_measure, + segmentation, + ac_coords=ac_coords, + pc_coords=pc_coords, + num_thickness_points=num_thickness_points, + subdivisions=subdivisions, + subdivision_method=subdivision_method, + contour_smoothing=contour_smoothing, + vox_size=vox_size, + ) + + # Process multiple slices or specific slice + if slice_selection == "middle": + num_slices = 1 + # Process only the middle slice + slices_to_recon = [segmentation.shape[0] // 2] + elif slice_selection == "all": + num_slices = segmentation.shape[0] + start_slice = 0 + end_slice = segmentation.shape[0] + slices_to_recon = range(start_slice, end_slice) + else: # specific slice number + num_slices = 1 + slices_to_recon = [int(slice_selection)] + + _gen_fsavg2slice_vox2vox = partial(create_sag_slice_vox2vox, fsaverage_middle=FSAVERAGE_MIDDLE) + per_slice_vox2ras = fsavg_vox2ras @ np.stack(list(map(_gen_fsavg2slice_vox2vox, slices_to_recon)), axis=0) + + per_slice_recon = process_executor().map(_each_slice, slices_to_recon, per_slice_vox2ras, chunksize=1) + cc_contours = [] + + run = thread_executor().submit + for i, (slice_idx, _results) in enumerate(zip(slices_to_recon, per_slice_recon, strict=True)): + progress = f" ({i+1} of {num_slices})" if num_slices > 1 else "" + logger.info(f"Calculating CC measurements for slice {slice_idx+1}{progress}") + # unpack values from _results + cc_measures: CCMeasuresDict = _results[0] + contour_in_as_space_and_thickness: ContourThickness = _results[1] + endpoint_idxs: tuple[int, int] = _results[2] + contour_in_as_space: Points2dType = contour_in_as_space_and_thickness[:, :2] + thickness_values: np.ndarray[tuple[int], np.dtype[np.float_]] = contour_in_as_space_and_thickness[:, 2] + + cc_contours.append(CCContour(contour_in_as_space, thickness_values, endpoint_idxs, resolution=vox_size[0])) + if cc_measures is None: + # this should not happen, but just in case + logger.warning(f"Slice index {slice_idx+1}{progress} returned result `None`") + + slice_cc_measures.append(cc_measures) + is_debug = logger.getEffectiveLevel() <= logging.DEBUG + is_midslice = slice_idx == num_slices // 2 + if subject_dir.has_attribute("cc_qc_image") and (is_debug or is_midslice): + qc_imgs: list[Path] = [subject_dir.filename_by_attribute("cc_qc_image")] + if is_debug: + qc_slice_img = qc_imgs[0].with_suffix(f".slice_{slice_idx}.png") + qc_imgs = (qc_imgs if is_midslice else []) + [qc_slice_img] + + logger.info(f"Saving segmentation qc image to {', '.join(map(str, qc_imgs))}") + current_slice_in_volume = midslices.shape[0] // 2 - num_slices // 2 + slice_idx + # Create visualization for this slice + io_futures.append( + run( + plot_contours, + transformed=midslices[current_slice_in_volume:current_slice_in_volume+1], + split_contours=cc_measures["split_contours"], + midline_equidistant=cc_measures["midline_equidistant"], + levelpaths=cc_measures["levelpaths"], + output_path=qc_imgs, + ac_coords=ac_coords, + pc_coords=pc_coords, + vox_size=vox_size, + title=f"CC Subsegmentation by {subdivision_method} (Slice {slice_idx + 1})", + ) + ) + + + if subject_dir.has_attribute("save_template_dir"): + template_dir = subject_dir.filename_by_attribute("save_template_dir") + # ensure directory exists + template_dir.mkdir(parents=True, exist_ok=True) + logger.info("Saving template files (contours.txt, thickness_values.txt, " + f"thickness_measurement_points.txt) to {template_dir}") + run = run + for j in range(len(cc_contours)): + # FIXME: check, if this is fixed (thickness values not nan == 200) + # this does not seem to be thread-safe, do not parallelize! + io_futures.append(run(cc_contours[j].save_contour, template_dir / f"contour_{j}.txt")) + io_futures.append(run(cc_contours[j].save_thickness_values, template_dir / f"thickness_values_{j}.txt")) + + mesh_outputs = ("html", "mesh", "thickness_overlay", "surf", "thickness_image") + if len(cc_contours) > 1 and any(subject_dir.has_attribute(f"cc_{n}") for n in mesh_outputs): + _cc_contours = thread_executor().map(_resample_thickness, cc_contours) + cc_mesh = CCMesh.from_contours(list(_cc_contours), smooth=1) + if subject_dir.has_attribute("cc_html"): + logger.info(f"Saving CC 3D visualization to {subject_dir.filename_by_attribute('cc_html')}") + io_futures.append(run( + cc_mesh.plot_mesh,output_path=subject_dir.filename_by_attribute("cc_html")), + ) + + if subject_dir.has_attribute("cc_mesh"): + vtk_file_path = subject_dir.filename_by_attribute("cc_mesh") + logger.info(f"Saving vtk file to {vtk_file_path}") + io_futures.append(run(cc_mesh.write_vtk, vtk_file_path)) + + # the mesh is generated in upright coordinates, so we need to also transform to orig coordinates + cc_mesh = cc_mesh.to_fs_coordinates(lr_offset=FSAVERAGE_MIDDLE / vox_size[0]) + if subject_dir.has_attribute("cc_thickness_overlay"): + overlay_file_path = subject_dir.filename_by_attribute("cc_thickness_overlay") + logger.info(f"Saving overlay file to {overlay_file_path}") + io_futures.append(run(cc_mesh.write_morph_data, overlay_file_path)) + + if subject_dir.has_attribute("cc_surf"): + surf_file_path = subject_dir.filename_by_attribute("cc_surf") + logger.info(f"Saving surf file to {surf_file_path}") + io_futures.append(run(cc_mesh.write_fssurf, str(surf_file_path), subject_dir.conf_name)) + + if subject_dir.has_attribute("cc_thickness_image"): + thickness_image_path = subject_dir.filename_by_attribute("cc_thickness_image") + logger.info(f"Saving thickness image to {thickness_image_path}") + cc_mesh.snap_cc_picture(thickness_image_path, subject_dir.conf_name) + + if not slice_cc_measures: + logger.error("Error: No valid slices were found for postprocessing") + raise ValueError("No valid slices were found for postprocessing") + + return slice_cc_measures, io_futures + + +def _resample_thickness(contour: CCContour) -> CCContour: + """Resamples the thickness values of contour.""" + _c = copy(contour) + _c.fill_thickness_values() + return _c + + +def recon_cc_surf_measure( + segmentation: np.ndarray[Shape2d, np.dtype[np.int_]], + slice_idx: int, + affine: AffineMatrix4x4, + ac_coords: Vector2d, + pc_coords: Vector2d, + num_thickness_points: int, + subdivisions: list[float], + subdivision_method: SubdivisionMethod, + contour_smoothing: int, + vox_size: tuple[float, float, float], +) -> tuple[CCMeasuresDict, ContourThickness, tuple[int, int]]: + """Reconstruct surfaces and compute measures for a single slice for the corpus callosum. + + Parameters + ---------- + segmentation : np.ndarray + 3D segmentation array. + slice_idx : int + Index of the slice to process. + affine : AffineMatrix4x4 + 4x4 affine transformation matrix. + ac_coords : np.ndarray of shape (2,) and type float + Anterior commissure coordinates. + pc_coords : np.ndarray of shape (2,) and type float + Posterior commissure coordinates. + num_thickness_points : int + Number of points for thickness estimation. + subdivisions : list[float] + List of fractions for anatomical subdivisions. + subdivision_method : SubdivisionMethod + Method for contour subdivision ('shape', 'vertical', 'angular', or 'eigenvector'). + contour_smoothing : int + Gaussian sigma for contour smoothing. + vox_size : triplet of floats + LIA-oriented voxel size in millimeters. + + Returns + ------- + measures : CCMeasuresDict + Dictionary containing measurements if successful. + contour_with_thickness : np.ndarray + Contour points with thickness information, shape (3, N) for [x, y, thickness]. + endpoint_indices : pair of ints + Indices of the anterior and posterior endpoints on the contour. + + Raises + ------ + ValueError + If no CC is found in the specified slice. + + Notes + ----- + The function performs the following steps: + 1. Extracts CC contour and identifies endpoints. + 2. Converts coordinates to RAS space. + 3. Calculates thickness profile using Laplace equation. + 4. Computes shape metrics and subdivisions. + 5. Generates visualization data. + """ + cc_mask_slice: Mask2d = np.equal(segmentation[slice_idx], CC_LABEL) + if not np.any(cc_mask_slice): + raise ValueError(f"No CC found in slice {slice_idx}") + contour, endpoint_idxs = get_endpoints( + cc_mask_slice, + ac_coords, + pc_coords, + (vox_size[1], vox_size[2]), + return_coordinates=False, + contour_smoothing=contour_smoothing, + ) + contour_ras = convert_to_ras(contour, affine) + + endpoint_idxs: tuple[int, int] + contour_with_thickness: ContourThickness + midline_len, thickness, curvature, midline_equi, levelpaths, contour_with_thickness, endpoint_idxs = cc_thickness( + contour_ras[1:].T, + endpoint_idxs, + n_points=num_thickness_points, + ) + # thickness values in contour_with_thickness is not equally sampled, different shape + # to compute length of paths: diff between consecutive points (N-1, 2) => norm (N-1,) => sum (1,) + thickness_profile = np.stack([np.sum(np.linalg.norm(np.diff(x[:, :2], axis=0), axis=1)) for x in levelpaths]) + + acpc_contour_coords_ras = contour_ras[:, list(endpoint_idxs)].T + contour_in_acpc_space, ac_pt_acpc, pc_pt_acpc, rotate_back_acpc = transform_to_acpc_standard( + contour_ras[1:], + *acpc_contour_coords_ras[:, 1:], + ) + cc_index = calculate_cc_index(contour_in_acpc_space) + + # Apply different subdivision methods based on user choice + split_contours: ContourList + if subdivision_method == "shape": + _subdivisions = np.asarray(subdivisions) + areas, split_contours = subsegment_midline_orthogonal(midline_equi, _subdivisions, contour_ras[1:], plot=False) + split_contours = [transform_to_acpc_standard(split_contour, *acpc_contour_coords_ras[:, 1:])[0] + for split_contour in split_contours] + elif subdivision_method == "vertical": + areas, split_contours = subdivide_contour(contour_in_acpc_space, subdivisions, plot=False) + elif subdivision_method == "angular": + if not np.allclose(np.diff(subdivisions), np.diff(subdivisions)[0]): + raise ValueError( + f"Angular subdivision method (Hampel) only supports equidistant subdivision, " + f"but got: {subdivisions}. No measures are computed.", + ) + areas, split_contours = hampel_subdivide_contour(contour_in_acpc_space, num_rays=len(subdivisions), plot=False) + elif subdivision_method == "eigenvector": + pt0, pt1 = get_primary_eigenvector(contour_in_acpc_space) + contour_eigen, _, _, rotate_back_eigen = transform_to_acpc_standard(contour_in_acpc_space, pt0, pt1) + ac_pt_eigen, _, _, _ = transform_to_acpc_standard(ac_pt_acpc[:, None], pt0, pt1) + ac_pt_eigen = ac_pt_eigen[:, 0] + areas, split_contours = subdivide_contour(contour_eigen, subdivisions, oriented=True, hline_anchor=ac_pt_eigen) + split_contours = [rotate_back_eigen(split_contour) for split_contour in split_contours] + else: + raise ValueError(f"Invalid subdivision method {subdivision_method}") + + total_area = np.sum(areas) + total_perimeter = np.sum(np.sqrt(np.sum((np.diff(contour_ras[:, 1:], axis=0))**2, axis=1))) + circularity = 4 * np.pi * total_area / (total_perimeter**2) + + # Transform split contours back to original space + split_contours = [rotate_back_acpc(split_contour) for split_contour in split_contours] + + measures: CCMeasuresDict = { + "cc_index": cc_index, + "circularity": circularity, + "areas": np.asarray(areas), + "midline_length": midline_len, + "thickness": thickness, + "curvature": curvature, + "thickness_profile": thickness_profile, + "total_area": total_area, + "total_perimeter": total_perimeter, + "split_contours": split_contours, + "midline_equidistant": midline_equi, + "levelpaths": levelpaths, + "slice_index": slice_idx + } + return measures, contour_with_thickness, endpoint_idxs + + +def vectorized_line_test( + coords_x: np.ndarray[tuple[int], np.dtype[ScalarType]], + coords_y: np.ndarray[tuple[int], np.dtype[ScalarType]], + line_start: Vector2d, + line_end: Vector2d, +) -> np.ndarray[tuple[int], np.dtype[np.bool_]]: + """Vectorized version of point_relative_to_line for arrays of points. + + Parameters + ---------- + coords_x : np.ndarray + Array of x coordinates. + coords_y : np.ndarray + Array of y coordinates. + line_start : array-like + [x, y] coordinates of line start point. + line_end : array-like + [x, y] coordinates of line end point. + + Returns + ------- + np.ndarray + Boolean array where True means point is to the left of the line. + """ + # FIXME: rename this function to something more indicative + # Vector from line_start to line_end + line_vec = np.array(line_end) - np.array(line_start) + + # Vectors from line_start to all points (vectorized) + point_vec_x = coords_x - line_start[0] + point_vec_y = coords_y - line_start[1] + + # Cross product (vectorized): positive means point is to the left of the line + cross_products = line_vec[0] * point_vec_y - line_vec[1] * point_vec_x + + return cross_products > 0 + + +def get_unique_contour_points(split_contours: ContourList) -> list[Points2dType]: + """Get unique contour points from the split contours. + + Parameters + ---------- + split_contours : ContourList + List of split contours (subsegmentations), each containing x and y coordinates, each of shape (2, N). + + Returns + ------- + list[np.ndarray] + List of unique contour points for each subsegment, each of shape (N, 2). + + Notes + ----- + This is a workaround to retrospectively add voxel-based subdivision. + In the future, we could keep track of the subdivision lines for + every subdivision scheme. + + The function: + 1. Processes each contour point. + 2. Checks if it appears in other contours (with small tolerance). + 3. Collects points unique to each subsegment. + """ + # For each contour point, check if it appears in other contours + unique_contour_points: list[Points2dType] = [] + + for i, contour in enumerate(split_contours): + # Get points for this contour + contour_points: Points2dType = np.vstack((contour[0], -contour[1])).T # Shape: (N,2) + + # Check each point against all other contours + unique_points = [] + for point in contour_points: + is_unique = True + + # Compare against other contours + for j, other_contour in enumerate(split_contours): + if i == j: + continue + + other_points = np.vstack((other_contour[0], -other_contour[1])).T + + # Check if point exists in other contour (with small tolerance) + if np.any(np.all(np.abs(other_points - point) < 1e-6, axis=1)): + is_unique = False + break + + if is_unique: + unique_points.append(point) + + unique_contour_points.append(np.array(unique_points)) + + return unique_contour_points + + +def make_subdivision_mask( + slice_shape: Shape2d, + split_contours: ContourList, + vox_size: tuple[float, float], +) -> np.ndarray[Shape2d, np.dtype[np.int_]]: + """Create a mask for subdividing the corpus callosum based on split contours. + + Parameters + ---------- + slice_shape : pair of ints + Shape of the slice (rows, cols). + split_contours : ContourList + List of contours defining the subdivisions. + Each contour is a tuple of x and y coordinates. + vox_size : pair of floats + The voxel sizes of the image grid in AS orientation. + + Returns + ------- + np.ndarray + A mask of shape slice_shape where each pixel is labeled with a value + from SUBSEGEMNT_LABELS indicating which subdivision segment it belongs to. + + Notes + ----- + The function: + 1. Extracts unique contour points at subdivision boundaries. + 2. Creates coordinate grids for all points in the slice. + 3. Initializes mask with first segment label. + 4. For each subdivision line: + - Tests which points lie to the right of the line. + - Updates labels for those points. + """ + + # unique contour points are the points where sub-division lines were inserted + unique_contour_points: list[Points2dType] = get_unique_contour_points(split_contours) # shape (N, 2) + subdivision_segments = unique_contour_points[1:] + + for s in subdivision_segments: + if len(s) != 2: + logger.error(f"Subdivision segment {s} has {len(s)} points, expected 2") + + # Create coordinate grids for all points in the slice + rows, cols = slice_shape + y_coords, x_coords = np.mgrid[0:rows, 0:cols] + + cc_subsegment_lut_anterior_to_posterior = SUBSEGMENT_LABELS.copy() + cc_subsegment_lut_anterior_to_posterior.reverse() + + # Initialize with first segment label + subdivision_mask = np.full(slice_shape, cc_subsegment_lut_anterior_to_posterior[0], dtype=np.int32) + + # Process each subdivision line + for segment_idx, segment_points in enumerate(subdivision_segments): + # FIXME: names for line_start and line_end? + line_start: Vector2d = segment_points[0] / vox_size + line_end: Vector2d = segment_points[-1] / vox_size + + # Vectorized test: find all points to the right of this line + # FIXME: line defined by what? Is this inside the polygon or the line from line_start to line_end? + points_right_of_line = vectorized_line_test(x_coords, y_coords, line_start, line_end) + + # All points to the right of this line belong to the next segment or beyond + subdivision_mask[points_right_of_line] = cc_subsegment_lut_anterior_to_posterior[segment_idx + 1] + + return subdivision_mask + + +def check_area_changes(contours: list[np.ndarray], threshold: float = 0.3) -> bool: + """Check for large changes between consecutive CC areas and issue warnings. + + Parameters + ---------- + contours : list[np.ndarray] + List of contours. + threshold : float, default=0.3 + Threshold for relative change. + + Returns + ------- + bool + True if no large area changes are detected, False otherwise. + """ + + areas = np.asarray([np.sum(np.sqrt(np.sum((np.diff(contour, axis=0))**2, axis=1))) for contour in contours]) + + assert len(areas) > 1, "At least two areas are required to check for area changes" + + if np.any(areas == 0): + # One area is zero, the other is not - this is a 100% change + logger.warning(f"Areas {np.where(areas == 0)[0].tolist()} are zero mm²") + return False + + # Calculate relative change + relative_change = np.abs(np.diff(areas)) / areas[:-1] + + if np.any(where_change := relative_change > threshold): + indices = np.where(where_change)[0] + percent_change = relative_change[where_change] * 100 + logger.info( + f"Large corpus callosum area change after slices {indices.tolist()} detected: " + + ", ".join(f"areas {(i,i+1)} = ({areas[i]:.2f},{areas[i+1]:.2f}) mm² ({p:.1f}% change)" + for i, p in zip(indices, percent_change, strict=True)) + ) + return False + return True diff --git a/CorpusCallosum/shape/subsegment_contour.py b/CorpusCallosum/shape/subsegment_contour.py new file mode 100644 index 000000000..2aa1b8fd2 --- /dev/null +++ b/CorpusCallosum/shape/subsegment_contour.py @@ -0,0 +1,949 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +from typing import TYPE_CHECKING, Literal + +import matplotlib.pyplot as plt +import numpy as np +from scipy.spatial import ConvexHull + +from CorpusCallosum.utils.types import ContourList, Points2dType, Polygon2dType, Polygon3dType +from FastSurferCNN.utils import Mask2d, Mask3d, ScalarType, Vector2d, nibabelImage + +if TYPE_CHECKING: + import pandas as pd + +def minimum_bounding_rectangle(points: Points2dType) -> np.ndarray[tuple[Literal[4], Literal[2]], np.dtype[ScalarType]]: + """Find the smallest bounding rectangle for a set of points. + + Parameters + ---------- + points : array + An array of shape (N, 2) containing point coordinates. + + Returns + ------- + np.ndarray + Array of shape (4, 2) containing coordinates of the bounding box corners. + """ + pi2 = np.pi / 2.0 + points = np.asarray(points).T + + # get the convex hull for the points + hull_points = points[ConvexHull(points).vertices] + + # calculate edge angles + edges = hull_points[1:] - hull_points[:-1] + + angles = np.arctan2(edges[:, 1], edges[:, 0]) + + angles = np.abs(np.mod(angles, pi2)) + angles = np.unique(angles) + + # find rotation matrices + rotations = np.vstack([np.cos(angles), np.cos(angles - pi2), np.cos(angles + pi2), np.cos(angles)]).T + rotations = rotations.reshape((-1, 2, 2)) + + # apply rotations to the hull + rot_points = np.dot(rotations, hull_points.T) + + # find the bounding points + min_x = np.nanmin(rot_points[:, 0], axis=1) + max_x = np.nanmax(rot_points[:, 0], axis=1) + min_y = np.nanmin(rot_points[:, 1], axis=1) + max_y = np.nanmax(rot_points[:, 1], axis=1) + + # find the box with the best area + areas = (max_x - min_x) * (max_y - min_y) + best_idx = np.argmin(areas) + + # return the best box + x1 = max_x[best_idx] + x2 = min_x[best_idx] + y1 = max_y[best_idx] + y2 = min_y[best_idx] + r = rotations[best_idx] + + rval = np.zeros((4, 2)) + rval[0] = np.dot([x1, y2], r) + rval[1] = np.dot([x2, y2], r) + rval[2] = np.dot([x2, y1], r) + rval[3] = np.dot([x1, y1], r) + + return rval + + +def calc_subsegment_areas(split_contours: ContourList) -> np.ndarray[tuple[int], np.dtype[ScalarType]]: + """Calculate area of each subsegment using the shoelace formula. + + Parameters + ---------- + split_contours : list of np.ndarray + List of contour arrays, each of shape (2, N). + + Returns + ------- + subsegment_areas : array of floats + Array containing the area of each subsegment. + """ + # calculate area of each split contour using the shoelace formula + areas = np.abs([np.trapz(split_contour[1], split_contour[0]) for split_contour in split_contours]) + if len(areas) == 1: + return np.asarray(areas[0]) + return np.ediff1d(np.asarray(areas)[::-1], to_end=areas[-1]) + + +def subsegment_midline_orthogonal( + midline: Points2dType, + area_weights: np.ndarray[tuple[int], np.dtype[np.float_]], + contour: Polygon2dType, + plot: bool = True, + ax=None, + extremes=None, +) -> tuple[np.ndarray[tuple[int], np.dtype[ScalarType]], ContourList]: + """Subsegment contour orthogonally to the midline based on area weights. + + Parameters + ---------- + midline : array of floats + Array of shape (N, 2) containing midline points. + area_weights : array of floats + Array of weights for area-based subdivision. + contour : array of floats + Array of shape (2, M) containing contour points in as space. + plot : bool, optional + Whether to plot the results, by default True. + ax : matplotlib.axes.Axes, optional + Axes for plotting, by default None. + extremes : tuple, optional + Tuple of extreme points, by default None. + + Returns + ------- + subsegment_areas : array of floats + List of subsegment areas. + split_contours : list of np.ndarray + List of contour arrays for each subsegment. + """ + # FIXME: Here and in other places, the order of dimensions is pretty inconsistent, for example: midline is (N, 2), + # but contours are (2, N)... + + # FIXME: why does this code return subsegments that include all previous segments? + # get points after midline length of splits + + # get vertex closest to midline end + midline_end_idx = np.argmin(np.linalg.norm(contour.T - midline[-1], axis=1)) + # roll contour start to midline end + contour = np.roll(contour, -midline_end_idx, axis=1) + + edge_idx, edge_frac = np.divmod(len(midline) * np.array(area_weights), 1) + edge_idx = edge_idx.astype(int) + split_points = midline[edge_idx] + (midline[edge_idx + 1] - midline[edge_idx]) * edge_frac[:, None] + + # get edge for each split point + edge_directions = midline[edge_idx] - midline[edge_idx + 1] + # get vector perpendicular to each midline edge + edge_ortho_vectors = np.column_stack((-edge_directions[:, 1], edge_directions[:, 0])) + edge_ortho_vectors = edge_ortho_vectors / np.linalg.norm(edge_ortho_vectors, axis=1)[:, None] + + split_contours: ContourList = [contour] + + # FIXME: double loop should be vectorized, see commented code below for an initial attempt (not tested) + # also, finding intersections can be done more efficiently, instead of solving linear system for each segment + # we could just look for changes in the sign of cross products + # mid_to_contour: np.ndarray = contour[:, :, None] - split_points[:, None] + # mid_to_contour_length = np.linalg.norm(mid_to_contour, axis=0) + # mid_to_contour_norm = mid_to_contour / mid_to_contour_length[None] + # sin_theta = mid_to_contour_norm[0] * edge_ortho_vectors[1] - mid_to_contour_norm[1] * edge_ortho_vectors[0] + # index_on_contour, index_on_segment = np.where(sin_theta[:-1] * sin_theta[1:] < 0) + # sin_theta_x = sin_theta[index_on_segment] + # cos_theta_x = np.sqrt(1 - sin_theta_x * sin_theta_x) + # rot_mat = np.array([[cos_theta_x, -sin_theta_x], [sin_theta_x, cos_theta_x]]) + # # rotate mid_to_contour by sin_theta + # _mid_to_intersection = rot_mat.transpose(0, -1) @ mid_to_contour[:, None, (index_on_contour, index_on_segment)] + # mid_to_intersection = cos_theta_x * _mid_to_intersection[:, 0, :] + # intersection_points = split_points[:, index_on_segment] + mid_to_intersection + # mid_to_intersection_length = np.linalg.norm(mid_to_intersection, axis=0) + # + # + # for segment_idx in range(split_points.shape[1]): + # mask = index_on_segment == segment_idx + # if any(mask): + # # first_index and second_index are the indices on the contour + # # _first_index and _second_index are the indices on the intersection_points of this segment + # _first_index, _second_index, *_ = np.argsort(mid_to_intersection_length[mask]) + # first_index, second_index = index_on_contour[mask][[_first_index, _second_index]] + # if first_index > second_index: + # first_index, second_index = second_index, first_index + # _first_index, _second_index = _second_index, _first_index + # # connect first and second half + # start_to_cutoff = np.hstack( + # ( + # contour[:, :first_index + 1], # includes first_index + # intersection_points[:, mask][:, [_first_index, _second_index]], + # contour[:, second_index + 1 :], # excludes second_index + # ) + # ) + # split_contours.append(start_to_cutoff) + + for pt_idx, split_point in enumerate(split_points): + intersections = [] + for i in range(contour.shape[1] - 1): + # get contour segment + segment_start = contour[:, i] + segment_end = contour[:, i + 1] + segment_vector = segment_end - segment_start + + # Check for intersection with the perpendicular line + matrix = np.array([segment_vector, -edge_ortho_vectors[pt_idx]]).T + if np.linalg.matrix_rank(matrix) < 2: + continue # Skip parallel lines + + # Solve for intersection + t, s = np.linalg.solve(matrix, split_point - segment_start) + if 0 <= t <= 1: + intersection_point = segment_start + t * segment_vector + intersections.append((i, intersection_point)) + + # import matplotlib.pyplot as plt + # plt.figure() + # plt.plot(contour[0], contour[1], 'k-') + # plt.plot(midline[:,0], midline[:,1], 'k--') + # plt.plot(split_point[0], split_point[1], 'ro') + + # plt.plot([segment_start[0], segment_end[0]], [segment_start[1], segment_end[1]], 'bo', linewidth=2) + # plt.plot([split_point[0]-edge_ortho_vectors[pt_idx][0], split_point[0]+edge_ortho_vectors[pt_idx][0]], + # [split_point[1]-edge_ortho_vectors[pt_idx][1], + # split_point[1]+edge_ortho_vectors[pt_idx][1]], 'k-', linewidth=2) + # plt.show() + + # get the two intersections closest to split_point + intersections.sort(key=lambda x: np.linalg.norm(x[1] - split_point)) + + # Create new contours by splitting at intersections + if intersections: + first_index, first_intersection = intersections[1] + second_index, second_intersection = intersections[0] + + if first_index > second_index: + first_index, second_index = second_index, first_index + first_intersection, second_intersection = second_intersection, first_intersection + + first_index += 1 + # second_index += 1 + + # connect first and second half + start_to_cutoff = np.hstack( + ( + contour[:, :first_index], + first_intersection[:, None], + second_intersection[:, None], + contour[:, second_index + 1 :], + ) + ) + split_contours.append(start_to_cutoff) + else: + raise ValueError("No intersections found, this should not happen") + + # plot contour to first index, then split point, then contour to second index + + # import matplotlib.pyplot as plt + # plt.close() + # fig, ax = plt.subplots(1,1) + # ax.plot(contour[:, :first_index][0], contour[:, :first_index][1], '-', linewidth=2, color='grey', + # label='Contour to first index') + # ax.plot(first_intersection[0], first_intersection[1], 'o', markersize=8, color='red', + # label='First intersection') + # ax.plot(second_intersection[0], second_intersection[1], 'o', markersize=8, color='red', + # label='Second intersection') + # ax.plot(contour[:, second_index + 1:][0], contour[:, second_index + 1:][1], '-', linewidth=2, color='red', + # label='Contour to second index') + # ax.legend() + # ax.set_title('Split Contours') + # ax.set_aspect('equal') + # ax.axis('off') + # plt.show() + + if plot: + extremes = [midline[0], midline[-1]] + + plot_transform = None + if plot_transform is not None: + split_contours = [plot_transform(split_contour) for split_contour in split_contours] + contour = plot_transform(contour) + extremes = [plot_transform(extreme[:, None]) for extreme in extremes] + split_points = [plot_transform(split_point[:, None]) for split_point in split_points] + # split_points_vlines_start = plot_transform(split_points_vlines_start) + # split_points_vlines_end = plot_transform(split_points_vlines_end) + + import matplotlib.pyplot as plt + + if ax is None: + SHOW = True + fig, ax = plt.subplots(1, 1, figsize=(8, 6)) + ax.axis("equal") + else: + SHOW = False + # pretty plot with areas filled in the polygon and overall area annotated + colors = plt.cm.Spectral(np.linspace(0.2, 0.8, len(split_contours))) + for color, split_contour in zip(colors, split_contours, strict=True): + ax.fill(split_contour[0], split_contour[1], alpha=0.5, color=color) + # ax.text(np.mean(split_contour[0]), np.mean(split_contour[1]), f'{area_out[i]:.2f}', + # olor='black', fontsize=12) + # plot contour + ax.plot(contour[0], contour[1], "-", linewidth=2, color="grey") + # put text between split points + # add endpoints to split_points + split_points = split_points.tolist() + split_points.insert(0, extremes[0]) + split_points.append(extremes[1]) + # ax.scatter(np.array(split_points)[:,0], np.array(split_points)[:,1], color='black', s=20) + ax.plot(midline[:, 0], midline[:, 1], "k--", linewidth=2) + + # plot edge orthogonal to each split point + for i in range(0, len(edge_ortho_vectors)): + pt = split_points[i + 1] + length = 0.4 + ax.plot( + [pt[0] - edge_ortho_vectors[i][0] * length, pt[0] + edge_ortho_vectors[i][0] * length], + [pt[1] - edge_ortho_vectors[i][1] * length, pt[1] + edge_ortho_vectors[i][1] * length], + "k-", + linewidth=2, + ) + + # convert area_weights into fraction of total line length + # e.g. area_weights=[1/6, 1/2, 2/3, 3/4] to ['1/6', '2/3', ...] + # cumulative difference + area_weights_diff = [area_weights[0]] + for i in range(1, len(area_weights)): + area_weights_diff.append(area_weights[i] - area_weights[i - 1]) + area_weights_diff.append(1 - area_weights[-1]) + + for i in range(len(split_points) - 1): + # get_index of split_points[i] in midline + sp1_midline_idx = np.argmin(np.linalg.norm(midline - split_points[i], axis=1)) + sp2_midline_idx = np.argmin(np.linalg.norm(midline - split_points[i + 1], axis=1)) + + # get midpoint on midline + midpoint_idx = (sp1_midline_idx + sp2_midline_idx) // 2 + midpoint = midline[midpoint_idx] + + # get vector perpendicular to line between split points + vector = np.array(split_points[i + 1]) - np.array(split_points[i]) + vector = vector / np.linalg.norm(vector) + vector = np.array([-vector[1], vector[0]]) + + midpoint = midpoint - vector * 3 + # ax.text(midpoint[0]-5, midpoint[1]-5, f'{area_out[i]:.2f}', color='black', fontsize=12) + # ax.text(midpoint[0], midpoint[1], f'{area_weights_txt[i]}', color='black', fontsize=12, + # horizontalalignment='center', verticalalignment='center') + + # start point & end point + ax.plot(extremes[0][0], extremes[0][1], marker="o", markersize=8, color="black") + ax.plot(extremes[1][0], extremes[1][1], marker="o", markersize=8, color="black") + + # plot contour point 0 + # ax.scatter(contour[0,0], contour[1,0], color='red', s=120) + ax.set_title("Split Contours") + + if SHOW: + ax.axis("off") + ax.invert_xaxis() + ax.axis("equal") + plt.show() + + return calc_subsegment_areas(split_contours), split_contours + + +def hampel_subdivide_contour(contour: Polygon2dType, num_rays: int, plot: bool = False, ax=None) \ + -> tuple[np.ndarray[tuple[int], np.dtype[np.float_]], ContourList]: + # FIXME: needs docstring + # Find the extreme points in the x-direction + min_x_index = np.argmin(contour[0]) + contour = np.roll(contour, -min_x_index, axis=1) + + # get minimal bounding box around contour + min_bounding_rectangle = minimum_bounding_rectangle(contour) + + # get long edges of rectangle + rectangle_duplicate_last = np.vstack((min_bounding_rectangle, min_bounding_rectangle[0])) + long_edges = np.diff(rectangle_duplicate_last, axis=0) + long_edges = np.linalg.norm(long_edges, axis=1) + long_edges_idx = np.argpartition(long_edges, -2)[-2:] + + # select lower long edge + min_val = np.inf + min_idx = None + for i in long_edges_idx: + if rectangle_duplicate_last[i][1] < min_val: + min_val = rectangle_duplicate_last[i][1] + min_idx = i + + if rectangle_duplicate_last[i + 1][1] < min_val: + min_val = rectangle_duplicate_last[i + 1][1] + min_idx = i + + lowest_points = rectangle_duplicate_last[[min_idx, min_idx + 1]] + + # sort lowest points by x coordinate + if lowest_points[0, 0] < lowest_points[1, 0]: + lowest_points = lowest_points[::-1] + + # get midpoint of lower edge of rectangle + midpoint_lower_edge = np.mean(lowest_points, axis=0) + + # get angle of lower edge of rectangle to x-axis + angle_lower_edge = np.arctan2( + lowest_points[1, 1] - lowest_points[0, 1], lowest_points[1, 0] - lowest_points[0, 0] + ) + + # get angles for equally spaced rays + angles = np.linspace(-angle_lower_edge, -angle_lower_edge + np.pi, num_rays + 2, endpoint=True) # + np.pi *3 + angles = angles[1:-1] + + # create ray vectors + ray_vectors = np.vstack((np.cos(angles), np.sin(angles))) + # make ray vectors unit length + ray_vectors = ray_vectors / np.linalg.norm(ray_vectors, axis=0) + + # invert x of ray vectors + ray_vectors[0] = -ray_vectors[0] + + # Subdivision logic + split_contours: ContourList = [] + for ray_vector in ray_vectors.T: + intersections = [] + for i in range(contour.shape[1] - 1): + segment_start = contour[:, i] + segment_end = contour[:, i + 1] + segment_vector = segment_end - segment_start + + # Check for intersection with the ray + matrix = np.array([segment_vector, -ray_vector]).T + if np.linalg.matrix_rank(matrix) < 2: + continue # Skip parallel lines + + # Solve for intersection + t, s = np.linalg.solve(matrix, midpoint_lower_edge - segment_start) + if 0 <= t <= 1: + intersection_point = segment_start + t * segment_vector + intersections.append((i, intersection_point)) + + # Sort intersections by their position along the contour + intersections.sort() + + # Create new contours by splitting at intersections + if intersections: + first_index, first_intersection = intersections[0] + second_index, second_intersection = intersections[-1] + + start_to_cutoff = np.hstack( + ( + contour[:, :first_index], + first_intersection[:, None], + second_intersection[:, None], + contour[:, second_index + 1 :], + ) + ) + + # connect first and second half + split_contours.append(start_to_cutoff) + else: + raise ValueError("No intersections found, this should not happen") + + split_contours.append(contour) + split_contours = split_contours[::-1] + + # split_contours = split_contours[::-1] + + # Plotting logic + if plot: + import matplotlib.pyplot as plt + + if ax is None: + fig, ax = plt.subplots(1, 1, figsize=(8, 6)) + ax.axis("equal") + SHOW = True + else: + SHOW = False + min_bounding_rectangle_plot = np.vstack((min_bounding_rectangle, min_bounding_rectangle[0])) + # ax.plot(contour[0], contour[1], 'b-', label='Original Contour') + ax.plot(min_bounding_rectangle_plot[:, 0], min_bounding_rectangle_plot[:, 1], "k--") + ax.plot(midpoint_lower_edge[0], midpoint_lower_edge[1], "ko", markersize=8) + for ray_vector in ray_vectors.T: + ray_length = 15 + ray_vector *= -ray_length + ax.plot( + [midpoint_lower_edge[0], midpoint_lower_edge[0] + ray_vector[0]], + [midpoint_lower_edge[1], midpoint_lower_edge[1] + ray_vector[1]], + "k--", + ) + # pretty plot with areas files in the polygon and overall area annotated + colors = plt.cm.Spectral(np.linspace(0.2, 0.8, len(split_contours))) + for color, split_contour in zip(colors, split_contours, strict=True): + ax.fill(split_contour[0], split_contour[1], alpha=0.5, color=color) + ax.plot(contour[0], contour[1], "-", linewidth=2, color="grey") + + ax.set_title("Split Contours") + ax.axis("off") + if SHOW: + ax.axis("equal") + plt.show() + + return calc_subsegment_areas(split_contours), split_contours + + +def subdivide_contour( + contour: Polygon2dType, + area_weights: list[float], + plot: bool = False, + ax: plt.Axes | None = None, + plot_transform: Callable | None = None, + oriented: bool = False, + hline_anchor: np.ndarray | None = None +) -> tuple[np.ndarray[tuple[int], np.dtype[np.float_]], ContourList]: + """Subdivide contour based on area weights using vertical lines. + + Divides the contour into segments by drawing vertical lines at positions + determined by the area weights. The lines are drawn perpendicular to a + reference line connecting the extreme points of the contour. + + Parameters + ---------- + contour : np.ndarray + Array of shape (2, N) containing contour points. + area_weights : list[float] + List of weights for area-based subdivision. + plot : bool, optional + Whether to plot the results, by default False. + ax : matplotlib.axes.Axes, optional + Axes for plotting, by default None. + plot_transform : callable, optional + Function to transform points before plotting, by default None. + oriented : bool, optional + If True, use fixed horizontal reference line, by default False. + hline_anchor : np.ndarray, optional + Point to anchor horizontal reference line, by default None. + + Returns + ------- + areas : np.ndarray + Array of areas for each subsegment. + split_contours : list[np.ndarray] + List of contour arrays for each subsegment. + + Notes + ----- + The subdivision process: + 1. Finds extreme points in x-direction. + 2. Creates reference line between extremes. + 3. Calculates split points based on area weights. + 4. Divides contour using perpendicular lines at split points. + + """ + # Find the extreme points in the x-direction + min_x_index = np.argmax(contour[0]) + contour = np.roll(contour, -min_x_index, axis=1) + + min_x_index = 0 + max_x_index = np.argmin(contour[0]) + + if oriented: + contour_x_sorted = np.sort(contour[0]) + min_x = contour_x_sorted[0] + max_x = contour_x_sorted[-1] + extremes = (np.array([max_x, 0]), np.array([min_x, 0])) + + if hline_anchor is not None: + extremes = (np.array([max_x, hline_anchor[1]]), np.array([min_x, hline_anchor[1]])) + else: + extremes = (contour[:, min_x_index].copy(), contour[:, max_x_index].copy()) + # Calculate the line between the extreme points + start_point, end_point = extremes + line_vector = end_point - start_point + line_length = np.linalg.norm(line_vector) + + # Normalize the line vector + line_unit_vector = line_vector / line_length + + # Calculate the perpendicular vector + perp_vector = np.array([-line_unit_vector[1], line_unit_vector[0]]) + perp_vector = perp_vector / np.linalg.norm(perp_vector) + + if hline_anchor is None: + most_inferior_point = np.min(contour[1]) + # move extreme 1 down 5 mm below inferior point and extreme 2 the + # same distance (so the angle stays the same) + down_distance = (extremes[1][1] - most_inferior_point) * 1.3 + start_point = extremes[0] + down_distance * perp_vector + end_point = extremes[1] + down_distance * perp_vector + + else: + # get closest point on line to hline_anchor + intersection = start_point + line_unit_vector * np.dot(hline_anchor - start_point, line_unit_vector) + # get distance closest point on line to hline_anchor + distance = np.linalg.norm(intersection - hline_anchor) + # move start and end point the same distance + start_point = extremes[0] + distance * perp_vector + end_point = extremes[1] + distance * perp_vector + + extremes = (start_point, end_point) + + # Calculate the line between the extreme points + start_point, end_point = extremes + line_vector = end_point - start_point + line_length = np.linalg.norm(line_vector) + + # Normalize the line vector + line_unit_vector = line_vector / line_length + + # Calculate the perpendicular vector + perp_vector = np.array([-line_unit_vector[1], line_unit_vector[0]]) + + # Calculate split points based on area weights + split_points = [] + for weight in area_weights: + # current_weight = np.sum(area_weights[:i]) + split_distance = weight * line_length + split_point = start_point + split_distance * line_unit_vector + split_points.append(split_point) + + # Split the contour at the calculated split points + split_contours = [] + split_contours.append(contour) + for split_point in split_points: + intersections = [] + for i in range(contour.shape[1] - 1): + segment_start = contour[:, i] + segment_end = contour[:, i + 1] + segment_vector = segment_end - segment_start + + # Check for intersection with the perpendicular line + matrix = np.array([segment_vector, -perp_vector]).T + if np.linalg.matrix_rank(matrix) < 2: + continue # Skip parallel lines + + # Solve for intersection + t, s = np.linalg.solve(matrix, split_point - segment_start) + if 0 <= t <= 1: + intersection_point = segment_start + t * segment_vector + intersections.append((i, intersection_point)) + + # Sort intersections by their position along the contour + # intersections.sort() + + # get the two intersections that have the highest y coordinate + intersections.sort(key=lambda x: x[1][1], reverse=True) + + # Create new contours by splitting at intersections + if intersections: + first_index, first_intersection = intersections[1] + second_index, second_intersection = intersections[0] + + if first_index > second_index: + first_index, second_index = second_index, first_index + first_intersection, second_intersection = second_intersection, first_intersection + + first_index += 1 + # second_index += 1 + + # start_to_cutoff = np.hstack((contour[:, :first_index], first_intersection[:, None], + # second_intersection[:, None], contour[:, second_index + 1:])) + start_to_cutoff = np.hstack( + (first_intersection[:, None], contour[:, first_index:second_index], second_intersection[:, None]) + ) + + + # connect first and second half + split_contours.append(start_to_cutoff) + else: + raise ValueError("No intersections found, this should not happen") + + if plot: + # make vline at every split point + split_points_vlines_start = (np.array(split_points) - perp_vector * 1).T + split_points_vlines_end = (np.array(split_points) + perp_vector * 1).T + + if oriented: + # make another vline at start point and end point, this time not + # perpendicular to line but perpendicular to x-axis + start_point_vline = np.array([start_point, np.array([start_point[0], start_point[1] + 8])]) + end_point_vline = np.array([end_point, np.array([end_point[0], end_point[1] + 8])]) + else: + start_point_vline = np.array([start_point, start_point - perp_vector * 8]) + end_point_vline = np.array([end_point, end_point - perp_vector * 8]) + + if plot_transform is not None: + split_contours = [plot_transform(split_contour) for split_contour in split_contours] + contour = plot_transform(contour) + extremes = [plot_transform(extreme[:, None]) for extreme in extremes] + split_points = [plot_transform(split_point[:, None]) for split_point in split_points] + split_points_vlines_start = plot_transform(split_points_vlines_start) + split_points_vlines_end = plot_transform(split_points_vlines_end) + start_point_vline = plot_transform(start_point_vline.T).T + end_point_vline = plot_transform(end_point_vline.T).T + + import matplotlib.pyplot as plt + + if ax is None: + SHOW = True + fig, ax = plt.subplots(1, 1, figsize=(8, 6)) + ax.axis("equal") + else: + SHOW = False + # pretty plot with areas filled in the polygon and overall area annotated + colors = plt.cm.Spectral(np.linspace(0.2, 0.8, len(split_contours))) + for color, split_contour in zip(colors, split_contours, strict=True): + ax.fill(split_contour[0], split_contour[1], alpha=0.5, color=color) + # ax.text(np.mean(split_contour[0]), np.mean(split_contour[1]), + # f'{area_out[i]:.2f}', color='black', fontsize=12) + # plot contour + ax.plot(contour[0], contour[1], "-", linewidth=2, color="grey") + # dashed line between start point & end point + ax.plot( + np.vstack((extremes[0][0], extremes[1][0])), + np.vstack((extremes[0][1], extremes[1][1])), + "--", + linewidth=2, + color="grey", + ) + # markers at every split point + for i in range(split_points_vlines_start.shape[1]): + ax.plot( + np.vstack((split_points_vlines_start[:, i][0], split_points_vlines_end[:, i][0])), + np.vstack((split_points_vlines_start[:, i][1], split_points_vlines_end[:, i][1])), + "k-", + linewidth=2, + ) + + ax.plot(start_point_vline[:, 0], start_point_vline[:, 1], "--", linewidth=2, color="grey") + ax.plot(end_point_vline[:, 0], end_point_vline[:, 1], "--", linewidth=2, color="grey") + # put text between split points + # add endpoints to split_points + split_points.insert(0, extremes[0]) + split_points.append(extremes[1]) + # convert area_weights into fraction of total line length + # e.g. area_weights=[1/6, 1/2, 2/3, 3/4] to ['1/6', '2/3', ...] + # cumulative difference + area_weights_diff = [] + area_weights_diff.append(area_weights[0]) + for i in range(1, len(area_weights)): + area_weights_diff.append(area_weights[i] - area_weights[i - 1]) + area_weights_diff.append(1 - area_weights[-1]) + + # area_weights_txt = area_weights_txt / area_weights_txt[-1] + from fractions import Fraction + + area_weights_txt = [ + Fraction(area_weights_diff[i]).limit_denominator(1000) for i in range(len(area_weights_diff)) + ] + + for i in range(len(split_points) - 1): + midpoint = np.mean([split_points[i], split_points[i + 1]], axis=0) + # ax.text(midpoint[0]-5, midpoint[1]-5, f'{area_out[i]:.2f}', color='black', fontsize=12) + ax.text( + midpoint[0], + midpoint[1] - 5, + f"{area_weights_txt[i]}", + color="black", + fontsize=11, + horizontalalignment="center", + ) + + # start point & end point + ax.plot(extremes[0][0], extremes[0][1], marker="o", markersize=8, color="black") + ax.plot(extremes[1][0], extremes[1][1], marker="o", markersize=8, color="black") + + # plot contour 0 point + # ax.scatter(contour[0,0], contour[1,0], color='red', s=100) + + ax.set_title("Split Contours") + # ax.set_xlabel('X') + # ax.set_ylabel('Y') + + # axis off + ax.axis("off") + if SHOW: + ax.axis("equal") + plt.show() + + return calc_subsegment_areas(split_contours), split_contours + + +def transform_to_acpc_standard( + contour_ras: Polygon2dType | Polygon3dType, + ac_pt_ras: Vector2d, + pc_pt_ras: Vector2d, +) -> tuple[Polygon2dType, Vector2d, Vector2d, Callable[[Polygon2dType], Polygon2dType]]: + """Transform contour coordinates to AC-PC standard space. + + Transforms the contour coordinates by: + 1. Translating AC point to origin. + 2. Rotating to align PC point with posterior direction. + 3. Scaling to maintain AC-PC distance. + + Parameters + ---------- + contour_ras : array of floats + Array of shape (2, N) or (3, N) containing contour points in RAS space. + ac_pt_ras : np.ndarray + Anterior commissure point coordinates in AS space. + pc_pt_ras : np.ndarray + Posterior commissure point coordinates in AS space. + + Returns + ------- + contour_acpc : np.ndarray + Transformed contour points in AC-PC space. + ac_pt_acpc : np.ndarray + AC point in AC-PC space (origin). + pc_pt_acpc : np.ndarray + PC point in AC-PC space. + rotate_back : callable + Function to transform points back to RAS space. + """ + # translate AC to the origin and PC to (0, ac_pc_dist) + translation_matrix = np.array([[1, 0, -ac_pt_ras[0]], [0, 1, -ac_pt_ras[1]], [0, 0, 1]]) + + ac_pc_vec: Vector2d = pc_pt_ras - ac_pt_ras + ac_pc_dist = np.linalg.norm(ac_pc_vec) + + posterior_vector: Vector2d = np.array([-ac_pc_dist, 0], dtype=float) + + # get angle between ac_pc_vec and posterior_vector + dot_product = np.dot(ac_pc_vec, posterior_vector) + norms_product = np.linalg.norm(ac_pc_vec) * np.linalg.norm(posterior_vector) + theta = np.arccos(dot_product / norms_product) + + # Determine the sign of the angle using cross product + cross_product = np.cross(ac_pc_vec, posterior_vector) + if cross_product < 0: + theta = -theta + + # create rotation matrix for theta + rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], [0, 0, 1]]) + + # apply translation and rotation + if contour_ras.shape[0] == 2: + contour_ras_homogeneous = np.vstack([contour_ras, np.ones(contour_ras.shape[1])]) + else: + contour_ras_homogeneous = contour_ras + + contour_acpc: Polygon2dType = (rotation_matrix @ translation_matrix) @ contour_ras_homogeneous + contour_acpc = contour_acpc[:2, :] + + def rotate_back(x: Polygon2dType) -> Polygon2dType: + return (np.linalg.inv(rotation_matrix @ translation_matrix) @ np.vstack([x, np.ones(x.shape[1])]))[:2, :] + + return contour_acpc, np.array([0, 0], dtype=float), np.array([-ac_pc_dist, 0], dtype=float), rotate_back + + +def preprocess_cc( + cc_label_nib: nibabelImage, + paths_csv: "pd.DataFrame", + subj_id: str, +) -> tuple[Mask2d, Vector2d, Vector2d]: + """Preprocess corpus callosum mask and extract AC/PC coordinates. + + Parameters + ---------- + cc_label_nib : nibabel.Nifti1Image + NIfTI image containing corpus callosum segmentation. + paths_csv : pd.DataFrame + DataFrame containing AC and PC coordinates. + subj_id : str + Subject ID to look up in paths_csv. + + Returns + ------- + cc_mask : np.ndarray + Binary mask of corpus callosum. + AC_2d : np.ndarray + 2D coordinates of anterior commissure. + PC_2d : np.ndarray + 2D coordinates of posterior commissure. + + """ + #FIXME: this function is not used anywhere + _cc_mask: Mask3d = np.asarray(cc_label_nib.dataobj) == 192 + cc_mask: Mask2d = _cc_mask[_cc_mask.shape[0] // 2] + + posterior_commisure_center = paths_csv.loc[subj_id, "PC_center_r": "PC_center_s"].to_numpy().astype(float) + anterior_commisure_center = paths_csv.loc[subj_id, "AC_center_r": "AC_center_s"].to_numpy().astype(float) + + # adjust LR from label coordinates to orig_up coordinates + posterior_commisure_center[0] = 128 + anterior_commisure_center[0] = 128 + + # orientation I, A + # rotate image so anterior and posterior commisure are horizontal + ac_2d = anterior_commisure_center[1:] + pc_2d = posterior_commisure_center[1:] + + return cc_mask, ac_2d, pc_2d + + +def get_primary_eigenvector(contour_ras: Polygon2dType) -> tuple[Vector2d, Vector2d]: + """Calculate primary eigenvector of contour points using PCA. + + Computes the principal direction of the contour by: + 1. Centering the points + 2. Computing covariance matrix + 3. Finding eigenvectors + 4. Selecting primary direction + + Parameters + ---------- + contour_ras : np.ndarray + Array of shape (2, N) containing contour points in RAS space. + + Returns + ------- + pt0 : np.ndarray + Starting point for eigenvector line. + pt1 : np.ndarray + End point for eigenvector line. + + """ + # Center the data by subtracting mean + contour_mean = np.mean(contour_ras, axis=1, keepdims=True) + contour_centered = contour_ras - contour_mean + + # Calculate covariance matrix + cov_matrix = np.cov(contour_centered) + + # Get eigenvalues and eigenvectors using PCA + eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix) + + # Sort in descending order + idx = eigenvalues.argsort()[::-1] + eigenvalues = eigenvalues[idx] + eigenvectors = eigenvectors[:, idx] + + # make first eigenvector unit length + primary_eigenvector = eigenvectors[:, 0] / np.linalg.norm(eigenvectors[:, 0]) + pt0 = np.mean(contour_ras, axis=1) + pt0 -= np.array([0, 5]) + pt1 = pt0 + primary_eigenvector * 100 + # plot mask with eigentvector + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots(1,2,figsize=(10, 8)) + # ax[0].imshow(cc_mask, cmap='gray') + # # plot line between pt0 and pt1 + # ax[0].plot([pt0[0], pt1[0]], [pt0[1], pt1[1]], 'r-', linewidth=2) + # plt.show() + + return pt0, pt1 + diff --git a/CorpusCallosum/shape/thickness.py b/CorpusCallosum/shape/thickness.py new file mode 100644 index 000000000..0303d1fe3 --- /dev/null +++ b/CorpusCallosum/shape/thickness.py @@ -0,0 +1,455 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Literal, overload + +import numpy as np +import scipy.interpolate +from lapy import Solver, TriaMesh +from lapy.diffgeo import compute_rotated_f +from meshpy import triangle + +from CorpusCallosum.utils.types import ContourThickness, Points2dType +from FastSurferCNN.utils.common import suppress_stdout + + +def compute_curvature(path: Points2dType) -> np.ndarray[tuple[int], np.dtype[np.float_]]: + """Compute curvature by computing edge angles. + + Parameters + ---------- + path : np.ndarray + Array of shape (N, 2) containing path coordinates. + + Returns + ------- + np.ndarray + Array of angle differences between consecutive edges. + """ + # compute curvature by computing edge angles + edges = np.diff(path, axis=0) + angles = np.arctan2(edges[:, 1], edges[:, 0]) + # compute angle differences between consecutive edges + angle_diffs = np.diff(angles) + # wrap angles to [-pi, pi] + angle_diffs = np.mod(angle_diffs + np.pi, 2 * np.pi) - np.pi + return angle_diffs + + +@overload +def convert_to_ras(contour: np.ndarray, vox2ras_matrix: np.ndarray, get_parameters: Literal[False] = False) \ + -> np.ndarray: ... + +@overload +def convert_to_ras(contour: np.ndarray, vox2ras_matrix: np.ndarray, get_parameters: Literal[True]) \ + -> tuple[np.ndarray, bool, bool, bool]: ... + + +def convert_to_ras( + contour: np.ndarray, + vox2ras_matrix: np.ndarray, + return_parameters: bool = False +): + """Convert contour coordinates from voxel space to RAS space. + + Parameters + ---------- + contour : np.ndarray + Array of shape (2, N) or (3, N) containing contour coordinates. + vox2ras_matrix : np.ndarray + 4x4 voxel to RAS transformation matrix. + return_parameters : bool, default=False + If True, return additional transformation parameters (see below). + + Returns + ------- + contour : np.ndarray + Transformed contour coordinates of shape (3, N). + anterior_reversed : bool + Only if return_parameters is True, whether anterior axis was reversed. + superior_reversed : bool + Only if return_parameters is True, whether superior axis was reversed. + swap_axes : bool + Only if return_parameters is True, whether axes were swapped. + """ + # converting to AS (no left-right dimension), out of plane movement is ignored, + # so we only do scaling, axes swapping and flipping - no rotation + # translation is ignored + if contour.shape[0] == 2: + # get only axis swaps from the rotation part of the vox2ras matrix + axis_swaps = np.round(vox2ras_matrix[:3, :3], 0) + permutation = np.argwhere(axis_swaps != 0)[:, 1] + assert len(permutation) == 3 + + idx_superior = np.argwhere(permutation == 2) + idx_anterior = np.argwhere(permutation == 1) + + # swap axes if indicated from vox2ras + if swap_axes := idx_anterior > idx_superior: + # swap anterior and superior + contour = contour[[1, 0]] + + # determine if axis were reversed + superior_reversed = np.any(axis_swaps[2, :] == -1) + anterior_reversed = np.any(axis_swaps[1, :] == -1) + + # flip axes if necessary + if superior_reversed: + contour[1] = -contour[1] + if anterior_reversed: + contour[0] = -contour[0] + + # get scaling by getting length of three column vectors + scaling = np.linalg.norm(vox2ras_matrix[:3, :3], axis=0) + + # voxel * vox_size = mm + contour = (contour.T * scaling[1:]).T + + # append a 0-R coordinate + contour = np.concatenate([np.zeros((1, contour.shape[1])), contour], axis=0) + + if return_parameters: + return contour, anterior_reversed, superior_reversed, swap_axes + else: + return contour + + # Add a third dimension (z) with 0 and a fourth dimension (homogeneous coordinate) with 1 + elif contour.shape[0] == 3: + contour_homogeneous = np.vstack([contour, np.ones(contour.shape[1])]) + + # Apply the transformation + contour = (vox2ras_matrix @ contour_homogeneous)[:3, :] + return contour + else: + raise ValueError("Invalid shape of contour") + + +def set_contour_zero_idx(contour, idx, anterior_endpoint_idx, posterior_endpoint_idx): + """Roll contour points to set a new zero index, while keeping track of CC endpoints. + + Parameters + ---------- + contour : np.ndarray + Array of contour points. + idx : int + New zero index. + anterior_endpoint_idx : int + Index of anterior endpoint. + posterior_endpoint_idx : int + Index of posterior endpoint. + + Returns + ------- + contour : np.ndarray + Rolled contour points. + anterior_endpoint_idx : int + Updated anterior endpoint index. + posterior_endpoint_idx : int + Updated posterior endpoint index. +""" + contour = np.roll(contour, -idx, axis=0) + anterior_endpoint_idx = (anterior_endpoint_idx - idx) % contour.shape[0] + posterior_endpoint_idx = (posterior_endpoint_idx - idx) % contour.shape[0] + return contour, anterior_endpoint_idx, posterior_endpoint_idx + + +def find_closest_edge(point, contour): + """Find the index of the edge closest to the given point. + + Parameters + ---------- + point : np.ndarray + 2D point coordinates. + contour : np.ndarray + Array of shape (N, 2) containing contour points. + + Returns + ------- + int + Index of the closest edge. + """ + edges_start = contour[:-1, :2] # N-1 x 2 + edges_end = contour[1:, :2] # N-1 x 2 + edges_vec = edges_end - edges_start # N-1 x 2 + + # Calculate projection coefficient for all edges at once + # (p-a)·(b-a) / |b-a|² + edge_lengths_sq = np.sum(edges_vec * edges_vec, axis=1) + # Avoid division by zero for degenerate edges + valid_edges = edge_lengths_sq > 1e-10 + t = np.zeros(len(edges_start)) + t[valid_edges] = ( + np.sum((point - edges_start[valid_edges]) * edges_vec[valid_edges], axis=1) + / edge_lengths_sq[valid_edges] + ) + t = np.clip(t, 0, 1) # Clamp to edge endpoints + + # Get closest points on all edges + closest_points = edges_start + t[:, None] * edges_vec + + # Calculate distances to all edges + distances = np.linalg.norm(point - closest_points, axis=1) + + # Return index of closest edge + return np.argmin(distances) + + +@overload +def insert_point_with_thickness( + contour_in_as_space: np.ndarray, + contour_thickness: np.ndarray, + point: np.ndarray, + thickness_value: float, + return_index: Literal[False] = False, +) -> tuple[np.ndarray, np.ndarray]: ... + + +@overload +def insert_point_with_thickness( + contour_in_as_space: np.ndarray, + contour_thickness: np.ndarray, + point: np.ndarray, + thickness_value: float, + return_index: Literal[True], +) -> tuple[np.ndarray, np.ndarray, int] | list[np.ndarray, np.ndarray]: + ... + + +def insert_point_with_thickness( + contour_in_as_space: np.ndarray, + contour_thickness: np.ndarray, + point: np.ndarray, + thickness_value: float, + return_index: bool = False +) -> tuple[np.ndarray, np.ndarray, int] | tuple[np.ndarray, np.ndarray]: + """Inserts a point and its thickness value into the contour. + + Parameters + ---------- + contour_in_as_space : np.ndarray + Array of coordinates of the contour in AS space, shape (N, 2). + contour_thickness : np.ndarray + Array of thickness values of the contour, shape (N,). + point : np.ndarray + 2D point to insert, shape (2,). + thickness_value : float + Thickness value corresponding to the point. + return_index : bool, default=False + If True, return the index where point was inserted, by default False. + + Returns + ------- + contour_in_as_space : np.ndarray + Updated contour of shape (N+1, 2). + contour_thickness : np.ndarray + Updated thickness values of shape (N+1,). + insertion_index : int + The index, where the point was inserted (only if return_index is True). + """ + # Find closest edge for the point + edge_idx = find_closest_edge(point, contour_in_as_space) + + # Insert point between edge endpoints + contour_in_as_space = np.insert(contour_in_as_space, edge_idx + 1, point, axis=0) + contour_thickness = np.insert(contour_thickness, edge_idx + 1, thickness_value) + + if return_index: + return contour_in_as_space, contour_thickness, edge_idx + 1 + else: + return contour_in_as_space, contour_thickness + + +def make_mesh_from_contour( + contour_2d: np.ndarray, + max_volume: float = 0.5, + min_angle: float = 25, + verbose: bool = False +) -> tuple[np.ndarray, np.ndarray]: + """Create a triangular mesh from a 2D contour. + + Parameters + ---------- + contour_2d : np.ndarray + Array of shape (N, 2) containing contour points. + max_volume : float, optional + Maximum triangle area, by default 0.5. + min_angle : float, optional + Minimum angle in triangles (degrees), by default 25. + verbose : bool, optional + Whether to print mesh generation info, by default False. + + Returns + ------- + tuple[np.ndarray, np.ndarray] + - mesh_points : Array of shape (M, 2) containing mesh vertices. + - mesh_trias : Array of shape (K, 3) containing triangle indices. + + Notes + ----- + Uses meshpy.triangle to create a constrained Delaunay triangulation + of the contour. The contour must not have duplicate points. + """ + + facets = np.vstack((np.arange(len(contour_2d)), ((np.arange(len(contour_2d)) + 1) % len(contour_2d)))).T + + # use meshpy to create mesh + info = triangle.MeshInfo() + info.set_points(contour_2d) + info.set_facets(facets) + # NOTE: crashes if contour has duplicate points !! + mesh = triangle.build(info, max_volume=max_volume, min_angle=min_angle, verbose=verbose) + + mesh_points = np.array(mesh.points) + mesh_trias = np.array(mesh.elements) + + return mesh_points, mesh_trias + + +def cc_thickness( + contour_2d: Points2dType, + endpoint_idx: tuple[int, int], + n_points: int = 100, +) -> tuple[float, float, float, Points2dType , list[Points2dType], ContourThickness, tuple[int, int]]: + """Calculate corpus callosum thickness using Laplace equation. + + Parameters + ---------- + contour_2d : np.ndarray + Array of shape (N, 2) containing contour points. + endpoint_idx : pair of ints + Indices of anterior and posterior endpoints in contour. + n_points : int, default=100 + Number of points for thickness measurement. + + Returns + ------- + midline_length : float + Total length of the midline. + thickness : float + Mean thickness across all level paths. + curvature : float + Mean absolute curvature in degrees. + midline_equidistant : np.ndarray + Equidistant points along the midline in same space as contour2d of shape (N, 2). + levelpaths : list[np.ndarray] + Level paths for thickness measurement in same space as contour2d, each of shape (N, 2). + contour_with_thickness : np.ndarray + Contour coordinates with thickness information in same space as contour2d of shape (N+2, 3). + endpoint_indices : pair of ints + Pair of updated indices of anterior and posterior endpoint. + + Notes + ----- + Uses the Laplace equation to compute thickness by: + 1. Creating a triangular mesh from the contour + 2. Setting boundary conditions (0 at endpoints, ±1 on sides) + 3. Solving Laplace equation to get level sets + 4. Computing thickness along level sets + """ + anterior_endpoint_idx, posterior_endpoint_idx = endpoint_idx + + # standardize contour indices to start at anterior_endpoint_idx, to get consistent levelpath directions + contour_2d, anterior_endpoint_idx, posterior_endpoint_idx = set_contour_zero_idx( + contour_2d, anterior_endpoint_idx, anterior_endpoint_idx, posterior_endpoint_idx, + ) + + mesh_points_contour_space, mesh_trias = make_mesh_from_contour(contour_2d) + + # make points 3D by appending z=0, asz space therefore is the contour space (usually AS space) with a zero z-dim + mesh_points_asz = np.append(mesh_points_contour_space, np.zeros((mesh_points_contour_space.shape[0], 1)), axis=1) + + # compute poisson + with suppress_stdout(): + tria_asz = TriaMesh(mesh_points_asz, mesh_trias) + # extract boundary curve + bdr = np.array(tria_asz.boundary_loops()[0]) + + # find index of endpoints in bdr list + iidx1 = np.where(bdr == anterior_endpoint_idx)[0][0] + iidx2 = np.where(bdr == posterior_endpoint_idx)[0][0] + + # create boundary condition (0 at endpoints, -1 on one side, 1 on the other): + if iidx1 > iidx2: + iidx1, iidx2 = iidx2, iidx1 + dcond = np.ones(bdr.shape) + dcond[iidx1] = 0 + dcond[iidx2] = 0 + dcond[iidx1 + 1 : iidx2] = -1 + + # Extract path + fem = Solver(tria_asz) + vfunc = fem.poisson(0, (bdr, dcond)) + midline_length: float + midline_equidistant_asz, midline_length = tria_asz.level_path(vfunc, level=0., n_points=n_points + 2) + midline_equidistant_contour_space: np.ndarray = midline_equidistant_asz[:, :2] + + gf = compute_rotated_f(tria_asz, vfunc) + + # interpolate midline to get levels to evaluate + level_of_rotated_laplace_contour_space = scipy.interpolate.griddata( + tria_asz.v[:, 0:2], gf, midline_equidistant_asz[:, 0:2], method="cubic", + ) + + # get levels to evaluate + levelpaths_contour_space: list[Points2dType] = [] + levelpath_lengths = [] + levelpath_tria_idx = [] + + # now, on the rotated laplace function, sample equally spaced (on midline: level_of_rotated_laplace) levelpaths + contour_thickness = np.full(contour_2d.shape[0], np.nan) + for current_level in level_of_rotated_laplace_contour_space[1:-1]: + # levelpath starts at index zero + levelpath_asz, lvlpath_length, tria_idx = tria_asz.level_path(gf, current_level, get_tria_idx=True) + + levelpaths_contour_space.append(levelpath_asz[:, :2]) + levelpath_lengths.append(lvlpath_length) + levelpath_tria_idx.append(tria_idx) + + levelpath_start = levelpath_asz[0, :2] + levelpath_end = levelpath_asz[-1, :2] + + contour_2d, contour_thickness, inserted_idx_start = insert_point_with_thickness( + contour_2d, contour_thickness, levelpath_start, lvlpath_length, return_index=True, + ) + # keep track of start index + if inserted_idx_start <= anterior_endpoint_idx: + anterior_endpoint_idx += 1 + if inserted_idx_start >= posterior_endpoint_idx: + posterior_endpoint_idx += 1 + + contour_2d, contour_thickness, inserted_idx_end = insert_point_with_thickness( + contour_2d, contour_thickness, levelpath_end, lvlpath_length, return_index=True, + ) + # keep track of end index + if inserted_idx_end <= anterior_endpoint_idx: + anterior_endpoint_idx += 1 + if inserted_idx_end >= posterior_endpoint_idx: + posterior_endpoint_idx += 1 + + contour_2d_with_thickness = np.concatenate([contour_2d, contour_thickness[:, None]], axis=1) + + # get curvature of path3d_resampled + curvature = compute_curvature(midline_equidistant_contour_space) + mean_curvature: float = np.abs(np.degrees(np.mean(curvature))).item() / len(curvature) + mean_thickness: float = np.mean(levelpath_lengths).item() + endpoints: tuple[int, int] = (anterior_endpoint_idx, posterior_endpoint_idx) + + return ( + midline_length, + mean_thickness, + mean_curvature, + midline_equidistant_contour_space, + levelpaths_contour_space, + contour_2d_with_thickness, + endpoints, + ) diff --git a/CorpusCallosum/transforms/__init__.py b/CorpusCallosum/transforms/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/CorpusCallosum/transforms/localization.py b/CorpusCallosum/transforms/localization.py new file mode 100644 index 000000000..e129fc820 --- /dev/null +++ b/CorpusCallosum/transforms/localization.py @@ -0,0 +1,153 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from logging import getLogger + +import numpy as np +import torch +from monai.transforms import MapTransform, RandomizableTransform + + +class CropAroundACPCFixedSize(RandomizableTransform, MapTransform): + """Crop image around AC-PC points with fixed size. + + A transform that crops the input image around the midpoint between + AC and PC points with a fixed size window and optional random translation. + + Parameters + ---------- + keys : list[str] + Keys of the data dictionary to apply the transform to. + fixed_size : tuple[int, int] + Fixed size of the crop window (width, height). + allow_missing_keys : bool, optional + Whether to allow missing keys in the data dictionary, by default False. + random_translate : int, default=0 + Maximum random translation in voxels. + + Raises + ------ + ValueError + If the crop boundaries extend outside the image dimensions. + + Notes + ----- + The transform expects the following keys in the data dictionary: + + - AC_center : np.ndarray + Coordinates of anterior commissure + - PC_center : np.ndarray + Coordinates of posterior commissure + - image : np.ndarray + Input image to crop + + """ + + def __init__( + self, + keys: list[str], + fixed_size: tuple[int, int], + allow_missing_keys: bool = False, + random_translate: int = 0, + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self) + self.random_translate = random_translate + self.fixed_size = fixed_size + + def __call__(self, data: dict) -> dict: + """Apply the 2D crop transform to the data. + + Parameters + ---------- + data : dict + Dictionary containing the data to transform AND keys AC_center and PC_center, each of shape (B, 2). + + Returns + ------- + dict + Transformed data dictionary with cropped images and updated coordinates. + Also includes crop boundary information: + - crop_left : list[int] + - crop_right : list[int] + - crop_top : list[int] + - crop_bottom : list[int] + + Raises + ------ + ValueError + If crop boundaries extend outside the image dimensions + """ + d = dict(data) + + expected_keys = {"PC_center", "AC_center"} | set(self.keys) if not self.allow_missing_keys else {} + + if expected_keys & set(d.keys()) != expected_keys: + raise ValueError(f"The following keys are missing in the data dictionary: {expected_keys - set(d.keys())}!") + + if any(d[k].ndim != 2 or d[k].shape[1] != 2 for k in ["PC_center", "AC_center"]): + raise ValueError("Shape of AC_center or PC_center incorrect, must be (B, 2)!") + + if any(d[k].ndim != 4 for k in self.keys if k in d.keys()): + raise ValueError(f"At least one key of {self.keys} does not have a 4-dimensional tensor.") + + # calculate center point between AC and PC + center_point = ((d['AC_center'] + d['PC_center']) / 2).astype(int) + + # Calculate voxel padding based on mm padding + voxel_padding = np.asarray(self.fixed_size) // 2 + + existing_keys = set(self.keys) & set(d.keys()) + if len(existing_keys) == 0: + getLogger(__name__).warning(f"None of the keys in {self.keys} are present in the data dictionary!") + return d + + first_key = tuple(existing_keys)[0] + + # Calculate crop boundaries with padding and random translation + crops = center_point - voxel_padding + + # Add random translation if specified + if self.random_translate > 0: + crops += np.random.randint( + -self.random_translate, + self.random_translate + 1, + size=(d[first_key].shape[0], 2), + ) + + # Ensure crop boundaries are within image + img_shape = np.asarray(d[first_key].shape[2:]) # Get spatial dimensions + if any(np.any(img_shape != d[k].shape[2:]) for k in self.keys if k in d.keys()): + raise ValueError(f"At least one key of {self.keys} does not have the expected shape.") + + patch_size_with_batch_dim = np.asarray(self.fixed_size)[None] + crops = np.maximum(0, np.minimum(img_shape, crops + patch_size_with_batch_dim) - patch_size_with_batch_dim) + d["crop_left"], d["crop_top"] = crops.T.tolist() + d["crop_right"], d["crop_bottom"] = (crops_end := crops + patch_size_with_batch_dim).T.tolist() + + # raise error if crop boundaries are out of image + if np.any(crops < 0) or np.any(crops_end > np.asarray([d[first_key].shape[2:]])): + raise ValueError("Crop boundaries are out of image") + + # Apply crop to image + for key in self.keys: + if key not in d.keys() and self.allow_missing_keys: + continue + arr = [v[:, cl:cr, ct:cb] for v, cl, ct, cr, cb in zip(d[key], *crops.T, *crops_end.T, strict=True)] + d[key] = torch.stack(arr, dim=0) if torch.is_tensor(arr[0]) else np.stack(arr, axis=0) + + # Update point coordinates relative to cropped image + d["PC_center"] = d["PC_center"] - crops + d["AC_center"] = d["AC_center"] - crops + return d diff --git a/CorpusCallosum/transforms/segmentation.py b/CorpusCallosum/transforms/segmentation.py new file mode 100644 index 000000000..2b54b450f --- /dev/null +++ b/CorpusCallosum/transforms/segmentation.py @@ -0,0 +1,180 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Literal + +import numpy as np +from monai.transforms import MapTransform, RandomizableTransform + + +class CropAroundACPC(RandomizableTransform, MapTransform): + """Crop image around anterior and posterior commissure points. + + A transform that crops the input image around the AC and PC points with + optional padding and random translation. + + Parameters + ---------- + keys : list[str] + Keys of the data dictionary to apply the transform to. + allow_missing_keys : bool, default=False + Whether to allow missing keys in the data dictionary. + padding_mm : float, default=10.0 + Padding around AC-PC region in millimeters. + random_translate : float, default=0 + Maximum random translation in voxels, off by default. + + Notes + ----- + The transform expects the following keys in the data dictionary: + + - AC_center : np.ndarray + Coordinates of anterior commissure + - PC_center : np.ndarray + Coordinates of posterior commissure + - res : float + Voxel resolution in mm + + """ + + def __init__(self, keys: list[str], allow_missing_keys: bool = False, + padding_mm: float = 10, random_translate: float = 0) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob=1, do_transform=True) + self.padding_mm = padding_mm + self.random_translate = random_translate + + def __call__(self, data: dict) -> dict: + """Apply the transform to the data. + + Parameters + ---------- + data : dict + Dictionary containing the data to transform. + + Returns + ------- + dict + Transformed data dictionary. + """ + d = dict(data) + + if "AC_center_original" not in d: + d["AC_center_original"] = d["AC_center"].copy() + if "PC_center_original" not in d: + d["PC_center_original"] = d["PC_center"].copy() + + if self.random_translate > 0: + random_translate = np.random.randint(-self.random_translate, self.random_translate, size=2) + else: + random_translate = (0,0,0) + + pc_center = d["PC_center"] + ac_center = d["AC_center"] + + ac_pc = np.stack([ac_center, pc_center], axis=0) + + ac_pc_bottomleft = np.min(ac_pc, axis=0).astype(int) + ac_pc_topright = np.max(ac_pc, axis=0).astype(int) + + VoxPadType = np.ndarray[tuple[Literal[2]], np.dtype[np.int_]] + voxel_padding: VoxPadType = np.round(self.padding_mm / d["res"]).astype(int) + + crop_left = ac_pc_bottomleft[1] - int(voxel_padding[0] * 1.5) + random_translate[0] + crop_right = ac_pc_topright[1] + voxel_padding[0] // 2 + random_translate[0] + crop_top = ac_pc_bottomleft[2] - voxel_padding[1] + random_translate[1] + crop_bottom = ac_pc_topright[2] + voxel_padding[1] + random_translate[1] + + keys_to_process = [key for key in self.keys if key in d.keys()] + + if not self.allow_missing_keys and set(keys_to_process) != set(self.keys): + raise ValueError("Some keys are missing in the data dictionary.") + + if len(keys_to_process) == 0: + logging.getLogger(__name__).warning("No keys to process.") + return d + + first_key = keys_to_process[0] + d["to_pad"] = crop_left, d[first_key].shape[2] - crop_right, crop_top, d[first_key].shape[3] - crop_bottom + + for key in keys_to_process: + d[key] = d[key][:, :, crop_left:crop_right, crop_top:crop_bottom] + + return d + + +class CropAroundACPCtrack(CropAroundACPC): + """Crop image around AC-PC points and update their coordinates. + + Extends CropAroundACPC to also adjust the AC and PC center coordinates + after cropping to maintain their correct positions in the cropped image. + + Parameters + ---------- + keys : list[str] + Keys of the data dictionary to apply the transform to. + allow_missing_keys : bool, optional + Whether to allow missing keys in the data dictionary, by default False. + padding_mm : float, optional + Padding around AC-PC region in millimeters, by default 10. + random_translate : float, optional + Maximum random translation in voxels, by default 0. + + Notes + ----- + The transform expects the following keys in the data dictionary: + + - AC_center : np.ndarray + Coordinates of anterior commissure + - PC_center : np.ndarray + Coordinates of posterior commissure + - AC_center_original : np.ndarray + Original coordinates of anterior commissure + - PC_center_original : np.ndarray + Original coordinates of posterior commissure + + """ + + def __call__(self, data: dict) -> dict: + """Apply the transform to the data. + + Parameters + ---------- + data : dict + Dictionary containing the data to transform. + + Returns + ------- + dict + Transformed data dictionary with updated AC and PC coordinates. + """ + + + # First call parent class to get cropped image + d = super().__call__(data) + + # Get the crop coordinates that were used + pad_left, pad_right, pad_top, pad_bottom = d["to_pad"] + + # Adjust AC and PC center coordinates based on cropping + if "AC_center" in d: + d["AC_center"][1] = d["AC_center_original"][1] - pad_left.item() + d["AC_center"][2] = d["AC_center_original"][2] - pad_top.item() + + if "PC_center" in d: + d["PC_center"][1] = d["PC_center_original"][1] - pad_left.item() + d["PC_center"][2] = d["PC_center_original"][2] - pad_top.item() + + return d + diff --git a/CorpusCallosum/utils/__init__.py b/CorpusCallosum/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/CorpusCallosum/utils/checkpoint.py b/CorpusCallosum/utils/checkpoint.py new file mode 100644 index 000000000..355542bd1 --- /dev/null +++ b/CorpusCallosum/utils/checkpoint.py @@ -0,0 +1,17 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT + +YAML_DEFAULT = FASTSURFER_ROOT / "CorpusCallosum/config/checkpoint_paths.yaml" diff --git a/CorpusCallosum/utils/mapping_helpers.py b/CorpusCallosum/utils/mapping_helpers.py new file mode 100644 index 000000000..0bd227bbc --- /dev/null +++ b/CorpusCallosum/utils/mapping_helpers.py @@ -0,0 +1,454 @@ +from pathlib import Path + +import nibabel as nib +import numpy as np +import SimpleITK as sitk +from numpy import typing as npt +from scipy.ndimage import affine_transform +from typing_extensions import overload + +from CorpusCallosum.data.constants import CC_LABEL, FORNIX_LABEL +from CorpusCallosum.utils.types import Polygon3dType +from FastSurferCNN.utils import ( + AffineMatrix4x4, + Image2d, + Image3d, + RotationMatrix3x3, + Shape3d, + Vector2d, + Vector3d, + logging, + nibabelImage, Image4d, +) +from FastSurferCNN.utils.parallel import thread_executor + +logger = logging.get_logger(__name__) + + +def make_midplane_affine( + orig_affine: AffineMatrix4x4, + slices_to_analyze: int = 1, + offset: int = 4, + ) -> AffineMatrix4x4: + """Create affine transformation matrix for midplane slices. + + Parameters + ---------- + orig_affine : AffineMatrix4x4 + Original image affine matrix (4x4). + slices_to_analyze : int, default=1 + Number of slices to analyze around midplane. + offset : int, default=4 + Additional offset in x direction. + + Returns + ------- + AffineMatrix4x4 + 4x4 affine matrix for midplane slices. + """ + # Create translation matrix to center on midplane + orig_to_seg = np.eye(4) + orig_to_seg[0, 3] = -256 // 2 + slices_to_analyze // 2 + offset + + # Combine with original affine + seg_affine = orig_affine @ np.linalg.inv(orig_to_seg) + + return seg_affine + + +def correct_nodding(ac_pt: Vector2d, pc_pt: Vector2d) -> RotationMatrix3x3: + """Calculate rotation matrix to correct head nodding. + + Calculates rotation matrix to align AC-PC line with posterior direction, + correcting for head nodding based on AC-PC line orientation. + + Parameters + ---------- + ac_pt : Vector2d + 2D coordinates of the anterior commissure point. + pc_pt : Vector2d + 2D coordinates of the posterior commissure point. + + Returns + ------- + RotationMatrix + 3x3 rotation matrix to align AC-PC line with posterior direction. + """ + ac_pc_vec = pc_pt - ac_pt + ac_pc_dist = np.linalg.norm(ac_pc_vec) + + posterior_vector = np.array([0, -ac_pc_dist]) + + # get angle between ac_pc_vec and posterior_vector + dot_product = np.dot(ac_pc_vec, posterior_vector) + norms_product = np.linalg.norm(ac_pc_vec) * np.linalg.norm(posterior_vector) + theta = np.arccos(dot_product / norms_product) + + # Determine the sign of the angle using cross product + cross_product = np.cross(ac_pc_vec, posterior_vector) + if cross_product < 0: + theta = -theta + + # create rotation matrix for theta + rotation_matrix: RotationMatrix3x3 = np.array( + [ + [np.cos(theta), -np.sin(theta), 0], + [np.sin(theta), np.cos(theta), 0], + [0, 0, 1], + ] + ) + + return rotation_matrix + + +@overload +def apply_transform_to_pt(pts: Vector3d, T: AffineMatrix4x4, inv: bool = False) -> Vector3d: ... + +@overload +def apply_transform_to_pt(pts: Polygon3dType, T: AffineMatrix4x4, inv: bool = False) -> Polygon3dType: ... + +def apply_transform_to_pt(pts: Vector3d | Polygon3dType, T: AffineMatrix4x4, inv: bool = False): + """Apply homogeneous transformation matrix to points. + + Parameters + ---------- + pts : np.ndarray + Point coordinates to transform, shape (3,) or (3, N). + T : np.ndarray + 4x4 homogeneous transformation matrix. + inv : bool, default=False + If True, applies inverse of transformation. + + Returns + ------- + np.ndarray + Transformed point coordinates, shape (3,) or (3, N). + """ + if inv: + T = np.linalg.inv(T) + + if pts.ndim == 1: + return (T @ np.hstack((pts, 1)))[:3] + else: + return (T @ np.concatenate([pts, np.ones((1, pts.shape[1]))]))[:3] + + +def calc_mapping_to_standard_space( + orig: "nib.Nifti1Image", + ac_coords_3d: Vector3d, + pc_coords_3d: Vector3d, + orig_fsaverage_vox2vox: AffineMatrix4x4, +) -> tuple[AffineMatrix4x4, Vector3d, Vector3d, Vector3d, Vector3d]: + """Get transformations to map image to standard space. + + Parameters + ---------- + orig : nib.Nifti1Image + Original image. + ac_coords_3d : np.ndarray + AC coordinates in 3D space. + pc_coords_3d : np.ndarray + PC coordinates in 3D space. + orig_fsaverage_vox2vox : AffineMatrix4x4 + Transformation matrix from original to fsaverage space. + + Returns + ------- + upright_volume : np.ndarray + Upright transformed volume. + standardized_volume : np.ndarray + Volume in standard space. + ac_coords_standardized : np.ndarray + AC coordinates in standard space. + pc_coords_standardized : np.ndarray + PC coordinates in standard space. + standardized_affine : np.ndarray + Affine matrix for standard space. + """ + image_center = np.array(orig.shape) / 2 + + # correct nodding + nod_correct_2d = correct_nodding(ac_coords_3d[1:3], pc_coords_3d[1:3]) + + # convert 2D nodding correction to 3D transformation matrix + nod_correct_3d: AffineMatrix4x4 = np.eye(4, dtype=float) + nod_correct_3d[1:3, 1:3] = nod_correct_2d[:2, :2] # Copy rotation part to y,z axes + # Copy translation part to y,z axes (usually no translation) + nod_correct_3d[1:3, 3] = nod_correct_2d[:2, 2] + + ac_coords_after_nodding: Vector3d = apply_transform_to_pt( + ac_coords_3d, nod_correct_3d, inv=False, + ) + pc_coords_after_nodding: Vector3d = apply_transform_to_pt( + pc_coords_3d, nod_correct_3d, inv=False, + ) + + ac_to_center_translation: AffineMatrix4x4 = np.eye(4, dtype=float) + ac_to_center_translation[:3, 3] = image_center - ac_coords_after_nodding + + # correct nodding + ac_coords_standardized: Vector3d = apply_transform_to_pt( + ac_coords_after_nodding, ac_to_center_translation, inv=False, + ) + pc_coords_standardized: Vector3d = apply_transform_to_pt( + pc_coords_after_nodding, ac_to_center_translation, inv=False, + ) + + standardized_to_orig_vox2vox: AffineMatrix4x4 = ( + np.linalg.inv(orig_fsaverage_vox2vox) + @ np.linalg.inv(nod_correct_3d) + @ np.linalg.inv(ac_to_center_translation) + ) + + # calculate ac & pc in space of mri input image + ac_coords_orig: Vector3d = apply_transform_to_pt( + ac_coords_standardized, standardized_to_orig_vox2vox, inv=False, + ) + pc_coords_orig: Vector3d = apply_transform_to_pt( + pc_coords_standardized, standardized_to_orig_vox2vox, inv=False, + ) + #FIXME: incorrect docstring + return standardized_to_orig_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig + + +def apply_transform_to_volume( + orig_image: nibabelImage, + vox2vox: AffineMatrix4x4, + affine: AffineMatrix4x4, + header: nib.freesurfer.mghformat.MGHHeader | None = None, + output_path: str | Path | None = None, + output_size: np.ndarray | None = None, + order: int = 1 +) -> npt.NDArray[float]: + """Apply transformation to a volume and save the result. + + Parameters + ---------- + orig_image : nibabelImage + Input volume. + vox2vox : np.ndarray + Transformation matrix to apply to the data, this is from input-to-output space. + affine : AffineMatrix4x4, optional + The vox2ras matrix of the output image, only relevant if output_path is given. + header : nibabelHeader, optional + Header for the output image, only relevant if output_path is given, if None will default to orig_image header. + output_path : str or Path, optional + If output_path is provided, saves the result under this path. + output_size : np.ndarray, optional + Size of output volume, uses input size by default `None`. + order : int, default=1 + Order of interpolation. + + Returns + ------- + npt.NDArray[float] + Transformed volume data. + + Notes + ----- + Uses `scipy.ndimage.affine_transform` for the transformation, and inverts vox2vox internally as required by + `affine_transform`. + """ + if output_size is None: + output_size = np.array(orig_image.shape) + if header is None: + header = orig_image.header + # transform / resample the volume with vox2vox, note this needs to be the inverse of input2output vox2vox! + # affine_transform definition is: input_coord = matrix @ output_coord + offset ( == MATRIX_HOM @ output_coord_hom) + # --> output_coord = inv(matrix) @ (input_coord - offset) ( == inv(MATRIX_HOM) @ input_coord_hom) + resampled = affine_transform(orig_image.get_fdata(), np.linalg.inv(vox2vox), output_shape=output_size, order=order) + if output_path is not None: + logger.info(f"Saving transformed volume to {output_path}") + nib.save(nib.MGHImage(resampled.astype(orig_image.get_data_dtype()), affine, header), output_path) + return resampled + + +def make_affine(simpleITKImage: sitk.Image) -> AffineMatrix4x4: + """Create an affine transformation matrix from a SimpleITK image. + + Parameters + ---------- + simpleITKImage : sitk.Image + Input SimpleITK image. + + Returns + ------- + np.ndarray + 4x4 affine transformation matrix in RAS coordinates. + + Notes + ----- + The function: + 1. Gets affine transform in LPS coordinates + 2. Converts to RAS coordinates to match nibabel + 3. Returns the final 4x4 transformation matrix + """ + # get affine transform in LPS + c = [simpleITKImage.TransformContinuousIndexToPhysicalPoint(p) for p in np.eye(4)[:, :3]] + c = np.array(c) + affine = np.concatenate( + [np.concatenate([c[0:3] - c[3:], c[3:]], axis=0), [[0.0], [0.0], [0.0], [1.0]]], + axis=1, + ) + affine = np.transpose(affine) + # convert to RAS to match nibabel + affine = np.matmul(np.diag([-1.0, -1.0, 1.0, 1.0]), affine) + return affine + + +@overload +def map_softlabels_to_orig( + cc_fn_softlabels: Image4d, + orig: nibabelImage, + orig2slab_vox2vox: AffineMatrix4x4, + cc_subseg_midslice: None = None, + orig2midslice_vox2vox: None = None, + orig_space_segmentation_path: str | Path | None = None, +) -> np.ndarray[Shape3d, np.dtype[np.int_]]: ... + + +@overload +def map_softlabels_to_orig( + cc_fn_softlabels: Image4d, + orig: nibabelImage, + orig2slab_vox2vox: AffineMatrix4x4, + cc_subseg_midslice: Image2d, + orig2midslice_vox2vox: AffineMatrix4x4, + orig_space_segmentation_path: str | Path | None = None, +) -> np.ndarray[Shape3d, np.dtype[np.int_]]: ... + + +def map_softlabels_to_orig( + cc_fn_softlabels: Image4d, + orig: nibabelImage, + orig2slab_vox2vox: AffineMatrix4x4, + cc_subseg_midslice: Image2d | None = None, + orig2midslice_vox2vox: AffineMatrix4x4 | None = None, + orig_space_segmentation_path: str | Path | None = None, +) -> np.ndarray[Shape3d, np.dtype[np.int_]]: + """Map soft labels back to original image space and apply post-processing. + + Parameters + ---------- + cc_fn_softlabels : np.ndarray + Soft label predictions of shape (H, W, D, C=3). + orig : nibabelImage + Original image. + orig2slab_vox2vox : AffineMatrix4x4 + The vox2vox transformation matrix from orig to the slab. + cc_subseg_midslice : np.ndarray, optional + Mask for subdividing regions of shape (H, D) (only paired with orig2midslice_vox2vox). + orig2midslice_vox2vox : AffineMatrix4x4, optional + The vox2vox transformation matrix from orig to the midslice (only paired with cc_subseg_midslice). + orig_space_segmentation_path : str or Path, optional + Path to save segmentation in original space. + + Returns + ------- + np.ndarray + Final segmentation in original image space. + + Notes + ----- + The function: + 1. Transforms background, cc, and fornix label channels separately. + 2. Transform CC subsegmentation from midslice to orig and paint into segmentation if `cc_subseg_midslice` is passed. + 4. Saves result to `orig_space_segmentation_path` if passed. + """ + # map softlabels to original image + def _map_softlabel_to_orig(data: Image3d, fill: int) -> Image3d: + # # Note: affine_transforms requires the inverse of the intended direction -> orig2slab + return affine_transform(data, orig2slab_vox2vox, output_shape=orig.shape, order=1, cval=fill) + + if cc_subseg_midslice is not None and orig2midslice_vox2vox is not None: + # map subdivision mask to orig space, this will also expand the labels into left-right direction + cc_subseg_orig_space_fut = thread_executor().submit( + affine_transform, + cc_subseg_midslice[None], + orig2midslice_vox2vox, # Note: affine_transforms requires the inverse of the intended direction + output_shape=orig.shape, + order=0, + mode="nearest", + ) + else: + cc_subseg_orig_space_fut = None + + _softlabels = np.moveaxis(cc_fn_softlabels, -1, 0) + softlabels_orig_space = np.stack(list(thread_executor().map(_map_softlabel_to_orig, _softlabels, [1., 0., 0.])), axis=-1) + # map to freesurfer labels + seg_lut = np.asarray([0, CC_LABEL, FORNIX_LABEL]) + seg_orig_space = seg_lut[np.argmax(softlabels_orig_space, axis=-1)] + + if cc_subseg_orig_space_fut is not None: + # replace CC_LABEL by subsegmentation labels + seg_orig_space = np.where(seg_orig_space == CC_LABEL, cc_subseg_orig_space_fut.result(), seg_orig_space) + + if orig_space_segmentation_path is not None: + logger.info(f"Saving segmentation in original space to {orig_space_segmentation_path}") + nib.save( + nib.MGHImage(seg_orig_space, orig.affine, orig.header), + orig_space_segmentation_path, + ) + return seg_orig_space + + +def interpolate_midplane( + orig: nibabelImage, + orig_fsaverage_vox2vox: AffineMatrix4x4, + slices_to_analyze: int, +) -> Image3d: + """Interpolates image data at the midplane using a grid of points. + + Parameters + ---------- + orig : nib.Nifti1Image + Original image. + orig_fsaverage_vox2vox : np.ndarray + Original to fsaverage space transformation matrix. + slices_to_analyze : int + Number of slices to analyze around midplane. + + Returns + ------- + np.ndarray + Interpolated image data at midplane. + """ + + # FIXME: this function is obsolete and can be removed - DK + + # slice_thickness = 9+slices_to_analyze-1 + # make grid of 9 slices in the fsaverage middle + # (cube from 123.5,0.5,0.5 to 132.5,255.5,255.5 (incudling end points, 1mm spacing)) + x_coords = np.linspace( + 124 - slices_to_analyze // 2, + 132 + slices_to_analyze // 2, + 9 + (slices_to_analyze - 1), + endpoint=True, + ) # 9 points from 123.5 to 132.5 + y_coords = np.linspace( + 0, orig.shape[1] - 1, orig.shape[1], endpoint=True + ) # 255 points from 0.5 to 255.5 + z_coords = np.linspace( + 0, orig.shape[2] - 1, orig.shape[2], endpoint=True + ) # 255 points from 0.5 to 255.5 + X, Y, Z = np.meshgrid(x_coords, y_coords, z_coords, indexing="ij") + + # Stack coordinates and add homogeneous coordinate + grid_fsaverage = np.stack([X.ravel(), Y.ravel(), Z.ravel(), np.ones(X.size)]) + + # move grid to orig space by applying transform + grid_orig = np.linalg.inv(orig_fsaverage_vox2vox) @ grid_fsaverage + + # interpolate grid on orig image + from scipy.ndimage import map_coordinates + + transformed = map_coordinates( + np.asarray(orig.dataobj), + grid_orig[0:3, :], # use only x,y,z coordinates (drop homogeneous coordinate) + order=2, + mode="constant", + cval=0, + prefilter=True, + ).reshape(len(x_coords), len(y_coords), len(z_coords)) + + return transformed diff --git a/CorpusCallosum/utils/types.py b/CorpusCallosum/utils/types.py new file mode 100644 index 000000000..52b45f9cf --- /dev/null +++ b/CorpusCallosum/utils/types.py @@ -0,0 +1,74 @@ +from typing import Literal, TypedDict + +from numpy import dtype, ndarray + +from FastSurferCNN.utils import ScalarType + +__all__ = [ + "CCMeasuresDict", + "ContourList", + "ContourThickness", + "Points2dType", + "Points3dType", + "Polygon2dType", + "Polygon3dType", + "SliceSelection", + "SubdivisionMethod", +] + +Polygon2dType = ndarray[tuple[Literal[2], int], dtype[ScalarType]] +Polygon3dType = ndarray[tuple[Literal[3], int], dtype[ScalarType]] +Points2dType = ndarray[tuple[int, Literal[2]], dtype[ScalarType]] +Points3dType = ndarray[tuple[int, Literal[3]], dtype[ScalarType]] +ContourList = list[Polygon2dType] +ContourThickness = ndarray[tuple[Literal[3], int], dtype[ScalarType]] +SliceSelection = Literal["middle", "all"] | int +SubdivisionMethod = Literal["shape", "vertical", "angular", "eigenvector"] + +class CCMeasuresDict(TypedDict): + """TypedDict for corpus callosum measures. + + Attributes + ---------- + cc_index : float + Corpus callosum shape index. + circularity : float + Shape circularity measure. + areas : np.ndarray + Areas of subdivided regions. + midline_length : float + Length along the midline. + thickness : float + Array of thickness measurements. + curvature : float + Array of curvature measurements. + thickness_profile : np.ndarray of type float + Thickness measurements along the contour. + total_area : float + Total area of the CC. + total_perimeter : float + Total perimeter length. + split_contours : list of np.ndarray + Subdivided contour segments in AS-slice coordinates. + midline_equidistant : np.ndarray + Equidistant points along midline in AS-slice coordinates. + levelpaths : list of np.ndarray + Paths for thickness measurements in AS-slice coordinates. + slice_index : int + Index of the processed slice. + """ + cc_index: float + circularity: float + areas: ndarray + midline_length: float + thickness: float + curvature: float + thickness_profile: ndarray[tuple[int], dtype[float]] + total_area: float + total_perimeter: float + total_area: float + total_perimeter: float + split_contours: ContourList + midline_equidistant: ndarray + levelpaths: list[ndarray] + slice_index: int diff --git a/CorpusCallosum/utils/visualization.py b/CorpusCallosum/utils/visualization.py new file mode 100644 index 000000000..3a44c6344 --- /dev/null +++ b/CorpusCallosum/utils/visualization.py @@ -0,0 +1,231 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import matplotlib +import matplotlib.pyplot as plt +import nibabel as nib +import numpy as np + + +def plot_standardized_space( + ax_row: list[plt.Axes], + vol: np.ndarray, + ac_coords: np.ndarray, + pc_coords: np.ndarray +) -> None: + """Plot standardized space visualization across three views. + + Parameters + ---------- + ax_row : list[plt.Axes] + Row of axes to plot on (should be length 3). + vol : np.ndarray + Volume data to visualize. + ac_coords : np.ndarray + AC coordinates in standardized space. + pc_coords : np.ndarray + PC coordinates in standardized space. + + Notes + ----- + Creates three views: + - Axial (top view) + - Sagittal (side view) + - Coronal (front view) + """ + ax_row[0].set_title("Standardized") + + for i, (a, b, _) in ((2, 1, "Axial"), (2, 0, "Sagittal"), (1, 0, "Coronal")): + ax_row[i].scatter(ac_coords[a], ac_coords[b], color="red", marker="x") + ax_row[i].scatter(pc_coords[a], pc_coords[b], color="blue", marker="x") + ax_row[i].imshow(vol[(slice(None),) * i + (vol.shape[i] // 2,)], cmap="gray") + + +def visualize_coordinate_spaces( + orig: "nib.Nifti1Image", + upright: np.ndarray, + standardized: np.ndarray, + ac_coords_orig: np.ndarray, + pc_coords_orig: np.ndarray, + ac_coords_3d: np.ndarray, + pc_coords_3d: np.ndarray, + ac_coords_standardized: np.ndarray, + pc_coords_standardized: np.ndarray, + output_plot_path: str | Path, +) -> None: + """Visualize the AC and PC coordinates in different coordinate spaces. + + Creates a figure showing the anterior and posterior commissure points + in three different coordinate spaces for testing/debugging. + + Parameters + ---------- + orig : nibabel.Nifti1Image + Original image volume. + upright : np.ndarray + Volume in fsaverage space. + standardized : np.ndarray + Volume in standardized space. + ac_coords_orig : np.ndarray + AC coordinates in original space. + pc_coords_orig : np.ndarray + PC coordinates in original space. + ac_coords_3d : np.ndarray + AC coordinates in fsaverage space. + pc_coords_3d : np.ndarray + PC coordinates in fsaverage space. + ac_coords_standardized : np.ndarray + AC coordinates in standardized space. + pc_coords_standardized : np.ndarray + PC coordinates in standardized space. + output_plot_path : str or Path + Directory to save visualization. + + Notes + ----- + Saves a visualization of the anterior (red) and posterior (blue) commisure in three different view: + 1. the orig image (orig), + 2. fs-average standardized image space, and + 3. standardized image space + as a single image named 'ac_pc_spaces.png' in `output_dir`. + """ + fig, ax = plt.subplots(3, 4) + ax = ax.T + + # Original space - using plot_standardized_space + plot_standardized_space(ax[0], np.asarray(orig.dataobj), ac_coords_orig, pc_coords_orig) + ax[0, 0].set_title("Orig") + + # Fsaverage space + plot_standardized_space(ax[1], upright, ac_coords_3d, pc_coords_3d) + ax[1, 0].set_title("Fsaverage") + + # Standardized space + plot_standardized_space(ax[2], standardized, ac_coords_standardized, pc_coords_standardized) + ax[2, 0].set_title("Standardized") + # Format all subplots + for a in ax.flatten(): + a.set_aspect("equal", adjustable="box") + a.axis("off") + + plt.savefig(output_plot_path, dpi=300, bbox_inches="tight") + plt.show() + plt.close() + + +def plot_contours( + transformed: np.ndarray, + split_contours: list[np.ndarray] | None = None, + midline_equidistant: np.ndarray | None = None, + levelpaths: list[np.ndarray] | None = None, + output_path: str | Path | list[Path] | None = None, + ac_coords: np.ndarray | None = None, + pc_coords: np.ndarray | None = None, + vox_size: tuple[float, float, float] | None = None, + title: str = "", +) -> None: + """Creates a figure of the contours (shape) and the subdivisions of the corpus callosum. + + Parameters + ---------- + transformed : np.ndarray + Transformed image data. + split_contours : list[np.ndarray], optional + List of contour arrays for each subdivision (ignore contours on None). + midline_equidistant : np.ndarray, optional + Midline points at equidistant spacing (ignore midline on None). + levelpaths : list[np.ndarray], optional + List of level paths for visualization (ignore level paths on None). + output_path : str or Path or list of Paths, optional + Path to save the plot (do not save on None). + ac_coords : np.ndarray, optional + AC coordinates for visualization (ignore AC on None). + pc_coords : np.ndarray, optional + PC coordinates for visualization (ignore PC on None). + vox_size : triplet of floats, optional + LIA-oriented voxel size for scaling, optional if none of split_contours, midline_equidistant, or levelpaths are + provided. + title : str, default="" + Title for the plot. + + Notes + ----- + Creates a visualization of the corpus callosum contours and their subdivisions. + If output_path is provided, saves the plot to that location. + """ + + if vox_size is None and None in (split_contours, midline_equidistant, levelpaths): + raise ValueError("vox_size must be provided if split_contours, midline_equidistant, or levelpaths are given.") + + if output_path is not None: + matplotlib.use('Agg') # Use non-GUI backend + + # convert vox_size from LIA to AS + vox_size_ras = np.asarray([vox_size[0], vox_size[2], vox_size[1]]) if vox_size is not None else None + + # scale contour data by vox_size to convert from AS to AS-aligned voxel space + _split_contours = [] if split_contours is None else [sp / vox_size_ras[1:, None] for sp in split_contours] + _midline_equi = np.zeros((0, 2)) if midline_equidistant is None else midline_equidistant / vox_size_ras[None, 1:] + _levelpaths = [] if levelpaths is None else [lp / vox_size_ras[None, 1:] for lp in levelpaths] + + has_first_plot = not (len(_split_contours) == 0 and ac_coords is None and pc_coords is None) + num_plots = 1 + int(has_first_plot) + + fig, ax = plt.subplots(1, num_plots, sharex=True, sharey=True, figsize=(15, 10)) + + # NOTE: For all plots imshow shows y inverted + current_plot = 0 + + if has_first_plot: + ax[current_plot].imshow(transformed[transformed.shape[0] // 2], cmap="gray") + ax[current_plot].set_title(title) + if _split_contours: + for i, this_contour in enumerate(_split_contours): + ax[current_plot].fill(this_contour[0, :], -this_contour[1, :], color="steelblue", alpha=0.25) + kwargs = {"color": "mediumblue", "linewidth": 0.7, "linestyle": "solid" if i != 0 else "dotted"} + ax[current_plot].plot(this_contour[0, :], -this_contour[1, :], **kwargs) + if ac_coords is not None: + ax[current_plot].scatter(ac_coords[1], ac_coords[0], color="red", marker="x") + if pc_coords is not None: + ax[current_plot].scatter(pc_coords[1], pc_coords[0], color="blue", marker="x") + current_plot += int(has_first_plot) + + ax[current_plot].imshow(transformed[transformed.shape[0] // 2], cmap="gray") + for this_path in _levelpaths: + ax[current_plot].plot(this_path[:, 0], -this_path[:, 1], color="brown", linewidth=0.8) + ax[current_plot].set_title("Midline & Levelpaths") + if _midline_equi.shape[0] > 0: + ax[current_plot].plot(_midline_equi[:, 0], -_midline_equi[:, 1], color="red") + if _split_contours: + reference_contour = _split_contours[0] + ax[current_plot].plot(reference_contour[0, :], -reference_contour[1, :], color="red", linewidth=0.5) + + padding = 30 + for a in ax.flatten(): + a.set_aspect("equal", adjustable="box") + a.axis("off") + if _split_contours: + reference_contour = _split_contours[0] + # get bounding box of contours + a.set_xlim(reference_contour[0, :].min() - padding, reference_contour[0, :].max() + padding) + a.set_ylim((-reference_contour[1, :]).max() + padding, (-reference_contour[1, :]).min() - padding) + + if output_path is None: + return plt.show() + for _output_path in (output_path if isinstance(output_path, (list, tuple)) else [output_path]): + Path(_output_path).parent.mkdir(parents=True, exist_ok=True) + fig.savefig(_output_path, dpi=300, bbox_inches="tight") + return None diff --git a/FastSurferCNN/data_loader/conform.py b/FastSurferCNN/data_loader/conform.py index d61cc58d9..e49924f99 100644 --- a/FastSurferCNN/data_loader/conform.py +++ b/FastSurferCNN/data_loader/conform.py @@ -19,7 +19,7 @@ import re import sys from collections.abc import Callable, Iterable, Sequence -from typing import TYPE_CHECKING, Literal, TypeVar, cast +from typing import TYPE_CHECKING, Literal, TypeVar, Union, cast import nibabel import nibabel as nib @@ -29,13 +29,8 @@ if TYPE_CHECKING: import torch -else: - # stub imports so TypeVar works - class torch: - class Tensor: - pass -from FastSurferCNN.utils import logging +from FastSurferCNN.utils import ScalarType, logging, nibabelImage from FastSurferCNN.utils.arg_types import ImageSizeOption, OrientationType, StrictOrientationType, VoxSizeOption from FastSurferCNN.utils.arg_types import float_gt_zero_and_le_one as __conform_to_one_mm from FastSurferCNN.utils.arg_types import img_size as __img_size @@ -64,9 +59,8 @@ class Tensor: LOGGER = logging.getLogger(__name__) -_TA = TypeVar("_TA", bound=np.ndarray | torch.Tensor) -_TB = TypeVar("_TB", bound=np.ndarray | torch.Tensor) -_TScalarType = TypeVar("_TScalarType", bound=np.number) +_TA = TypeVar("_TA", bound=Union[np.ndarray, "torch.Tensor"]) +_TB = TypeVar("_TB", bound=Union[np.ndarray, "torch.Tensor"]) def __rescale_type(a: str) -> float | int | None: @@ -373,21 +367,21 @@ def apply_orientation(arr: _TB | npt.ArrayLike, ornt: npt.NDArray[int]) -> _TB: def map_image( - img: nib.analyze.SpatialImage, + img: nibabelImage, out_affine: npt.NDArray[float], out_shape: tuple[int, ...] | npt.NDArray[int] | Iterable[int], ras2ras: npt.NDArray[np.number] | None = None, order: int = 1, - dtype: np.dtype[_TScalarType] | npt.DTypeLike | None = None, + dtype: np.dtype[ScalarType] | npt.DTypeLike | None = None, vox_eps: float = 1e-4, rot_eps: float = 1e-6, -) -> npt.NDArray[_TScalarType]: +) -> npt.NDArray[ScalarType]: """ Map image to new voxel space (RAS orientation). Parameters ---------- - img : nib.analyze.SpatialImage + img : nibabelImage The src 3D image with data and affine set. out_affine : np.ndarray Trg image affine. @@ -635,7 +629,7 @@ def rescale( def conform( - img: nib.analyze.SpatialImage, + img: nibabelImage, order: int = 1, vox_size: VoxSizeOption | None = 1.0, img_size: ImageSizeOption | None = 256, @@ -646,7 +640,7 @@ def conform( vox_eps: float = 1e-4, rot_eps: float = 1e-6, **kwargs, -) -> nib.analyze.SpatialImage: +) -> nibabelImage: """Python version of mri_convert -c. mri_convert -c by default turns image intensity values into UCHAR, reslices images to standard position, fills up @@ -654,7 +648,7 @@ def conform( Parameters ---------- - img : nib.analyze.SpatialImage + img : nibabelImage Loaded source image. order : int, default=1 Interpolation order (0=nearest, 1=linear, 2=quadratic, 3=cubic). @@ -777,7 +771,7 @@ def conform( def prepare_mgh_header( - img: nib.analyze.SpatialImage, + img: nibabelImage, target_vox_size: npt.NDArray[float] | None = None, target_img_size: npt.NDArray[int] | None = None, orientation: OrientationType = "native", @@ -903,7 +897,7 @@ def isclose(x, y, eps): def is_conform( - img: nib.analyze.SpatialImage, + img: nibabelImage, vox_size: VoxSizeOption | None = 1.0, img_size: ImageSizeOption | None = 256, dtype: npt.DTypeLike | None = np.uint8, @@ -921,7 +915,7 @@ def is_conform( Parameters ---------- - img : nib.analyze.SpatialImage + img : nibabelImage Loaded source image. vox_size : float, "min", None, default=1.0 Which voxel size to conform to. Can either be a float between 0.0 and 1.0, 'min' (to check, whether the image is @@ -1101,7 +1095,7 @@ def is_orientation( def conformed_vox_img_size( - img: nib.analyze.SpatialImage, + img: nibabelImage, vox_size: VoxSizeOption | None, img_size: ImageSizeOption | None, threshold_1mm: float | None = None, @@ -1115,7 +1109,7 @@ def conformed_vox_img_size( Parameters ---------- - img : nib.analyze.SpatialImage + img : nibabelImage Loaded source image. vox_size : float, "min", None The voxel size parameter to use: either a voxel size as float, or the string "min" to automatically find a diff --git a/FastSurferCNN/data_loader/data_utils.py b/FastSurferCNN/data_loader/data_utils.py index 0a2e8f2ca..6c036fef5 100644 --- a/FastSurferCNN/data_loader/data_utils.py +++ b/FastSurferCNN/data_loader/data_utils.py @@ -35,7 +35,7 @@ from skimage.measure import label, regionprops from FastSurferCNN.data_loader.conform import check_affine_in_nifti, conform, is_conform -from FastSurferCNN.utils import logging +from FastSurferCNN.utils import logging, nibabelImage ## # Global Vars @@ -88,7 +88,7 @@ def load_and_conform_image( If input has multiple input frames or inconsistent nifti headers. """ img_file = Path(img_filename) - orig = cast(nib.analyze.SpatialImage, nib.load(img_file)) + orig = cast(nibabelImage, nib.load(img_file)) # is_conform and conform accept numeric values and the string 'min' instead of the bool value if not is_conform(orig, **conform_kwargs): @@ -113,7 +113,7 @@ def load_image( file: str | Path, name: str = "image", **kwargs, -) -> tuple[nib.analyze.SpatialImage, np.ndarray]: +) -> tuple[nibabelImage, np.ndarray]: """ Load file 'file' with nibabel, including all data. @@ -128,9 +128,10 @@ def load_image( Returns ------- - Tuple[nib.analyze.SpatialImage, np.ndarray] - The nibabel image object and a numpy array of the data. - + the_image : nibabelImage + The SpatialImage object from nibabel of the conformed image (including updated affine). + the_data : np.ndarray + The data of the conformed image. Raises ------ IOError @@ -146,7 +147,7 @@ def load_image( } """ try: - img = cast(nib.analyze.SpatialImage, nib.load(file, **kwargs)) + img = cast(nibabelImage, nib.load(file, **kwargs)) except (OSError, FileNotFoundError) as e: raise OSError(f"Failed loading the {name} '{file}' with error: {e.args[0]}") from e return img, np.asarray(img.dataobj) @@ -156,7 +157,7 @@ def load_maybe_conform( file: Path | str, alt_file: Path | str, **conform_kwargs, -) -> tuple[Path, nib.analyze.SpatialImage, np.ndarray]: +) -> tuple[Path, nibabelImage, np.ndarray]: """ Load an image by file, check whether it is conformed to vox_size and conform to vox_size if it is not. @@ -174,7 +175,7 @@ def load_maybe_conform( ------- Path The path to the file. - nib.analyze.SpatialImage + nibabelImage The file container object including the corrected header. np.ndarray The data loaded from the file. @@ -192,7 +193,7 @@ def load_maybe_conform( _is_conform, img = False, None if file.is_file(): # see if the file is 1mm - img = cast(nib.analyze.SpatialImage, nib.load(file)) + img = cast(nibabelImage, nib.load(file)) # is_conform only needs the header, not the data _is_conform = is_conform(img, **conform_kwargs_is_conform, verbose=False, vox_eps=0.1) diff --git a/FastSurferCNN/download_checkpoints.py b/FastSurferCNN/download_checkpoints.py index 23f65febd..35492d79e 100644 --- a/FastSurferCNN/download_checkpoints.py +++ b/FastSurferCNN/download_checkpoints.py @@ -17,6 +17,7 @@ from CerebNet.utils.checkpoint import ( YAML_DEFAULT as CEREBNET_YAML, ) +from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML from FastSurferCNN.utils import PLANES from FastSurferCNN.utils.checkpoint import ( YAML_DEFAULT as VINN_YAML, @@ -26,9 +27,7 @@ get_checkpoints, load_checkpoint_config_defaults, ) -from HypVINN.utils.checkpoint import ( - YAML_DEFAULT as HYPVINN_YAML, -) +from HypVINN.utils.checkpoint import YAML_DEFAULT as HYPVINN_YAML class ConfigCache: @@ -40,9 +39,12 @@ def cerebnet_url(self): def hypvinn_url(self): return load_checkpoint_config_defaults("url", filename=HYPVINN_YAML) + + def cc_url(self): + return load_checkpoint_config_defaults("url", filename=CC_YAML) def all_urls(self): - return self.vinn_url() + self.cerebnet_url() + self.hypvinn_url() + return self.vinn_url() + self.cerebnet_url() + self.hypvinn_url() + self.cc_url() defaults = ConfigCache() @@ -72,6 +74,12 @@ def make_parser(): action="store_true", help="Check and download CerebNet default checkpoints", ) + parser.add_argument( + "--cc", + default=False, + action="store_true", + help="Check and download Corpus Callosum default checkpoints", + ) parser.add_argument( "--hypvinn", @@ -99,14 +107,15 @@ def make_parser(): def main( - vinn: bool, - cerebnet: bool, - hypvinn: bool, - all: bool, - files: list[str], + vinn: bool = False, + cerebnet: bool = False, + hypvinn: bool = False, + cc: bool = False, + all: bool = False, + files: list[str] = (), url: str | None = None, ) -> int | str: - if not vinn and not files and not cerebnet and not hypvinn and not all: + if not vinn and not files and not cerebnet and not hypvinn and not cc and not all: return ("Specify either files to download or --vinn, --cerebnet, " "--hypvinn, or --all, see help -h.") @@ -141,6 +150,16 @@ def main( *(hypvinn_config[plane] for plane in PLANES), urls=defaults.hypvinn_url() if url is None else [url], ) + # Corpus Callosum checkpoints + if cc or all: + cc_config = load_checkpoint_config_defaults( + "checkpoint", + filename=CC_YAML, + ) + get_checkpoints( + *(cc_config[model] for model in cc_config.keys()), + urls=defaults.cc_url() if url is None else [url], + ) for fname in files: check_and_download_ckpts( fname, diff --git a/FastSurferCNN/models/interpolation_layer.py b/FastSurferCNN/models/interpolation_layer.py index 947207b79..790bf6a77 100644 --- a/FastSurferCNN/models/interpolation_layer.py +++ b/FastSurferCNN/models/interpolation_layer.py @@ -144,40 +144,25 @@ def forward( should be equal to `target_voxelsize / source_voxelsize`. """ if self._N == -1: - raise RuntimeError( - "Direct instantiation of _InterpolateNd is not supported." - ) + raise RuntimeError("Direct instantiation of _InterpolateNd is not supported.") if input_tensor.dim() != 2 + self._N: - raise ValueError( - f"Expected {self._N+2}-dimensional input tensor, got {input_tensor.dim()}" - ) + raise ValueError(f"Expected {self._N+2}-dimensional input tensor, got {input_tensor.dim()}") if len(self._target_shape) == 0: - raise AttributeError( - "The target_shape was not set, but a valid value is required." - ) + raise AttributeError("The target_shape was not set, but a valid value is required.") - scales_chunks = list( - zip(*self._fix_scale_factors(scale_factors, input_tensor.shape[0]), strict=False) - ) + scales_chunks = list(zip(*self._fix_scale_factors(scale_factors, input_tensor.shape[0]), strict=False)) if len(scales_chunks) == 0: - raise ValueError( - f"Invalid scale_factors {scale_factors}, no chunks returned." - ) + raise ValueError(f"Invalid scale_factors {scale_factors}, no chunks returned.") scales, chunks = map(list, scales_chunks) interp, scales_out = [], [] # Pytorch Tensor shape BxCxHxW --> loop over batches, interpolate single images, concatenate output at end - for tensor, scale_f, num in zip( - torch.split(input_tensor, chunks, dim=0), scales, chunks, strict=False - ): + for tensor, scale_f, num in zip(torch.split(input_tensor, chunks, dim=0), scales, chunks, strict=False): if rescale: - if isinstance(scale_f, list): - scale_f = [1 / sf for sf in scale_f] - else: - scale_f = torch.div(1, scale_f) + scale_f = [1 / sf for sf in scale_f] if isinstance(scale_f, list) else torch.div(1, scale_f) image, sf = self._interpolate(tensor, scale_f) interp.append(image) scales_out.extend([sf] * num) @@ -453,11 +438,7 @@ def _interpolate( _T.Tuple[Tensor, T_Scale] The interpolated tensor and its scaling factor. """ - scale_factor = ( - scale_factor.tolist() - if isinstance(scale_factor, np.ndarray) - else scale_factor - ) + scale_factor = scale_factor.tolist() if isinstance(scale_factor, np.ndarray) else scale_factor if isinstance(scale_factor, Tensor) and scale_factor.shape == (2,): pass elif isinstance(scale_factor, _T.Sequence) and len(scale_factor) == 2: diff --git a/FastSurferCNN/quick_qc.py b/FastSurferCNN/quick_qc.py index d74a20de2..05bd4a2af 100644 --- a/FastSurferCNN/quick_qc.py +++ b/FastSurferCNN/quick_qc.py @@ -174,10 +174,12 @@ def get_ventricle_bg_intersection_volume(seg_array, voxvol): if __name__ == "__main__": + from FastSurferCNN.utils import nibabelImage + # Command Line options are error checking done here options = options_parse() print(f"Reading in aparc+aseg: {options.asegdkt_segfile} ...") - inseg = cast(nib.analyze.SpatialImage, nib.load(options.asegdkt_segfile)) + inseg = cast(nibabelImage, nib.load(options.asegdkt_segfile)) inseg_data = np.asanyarray(inseg.dataobj) inseg_header = inseg.header inseg_voxvol = np.prod(inseg_header.get_zooms()) diff --git a/FastSurferCNN/reduce_to_aseg.py b/FastSurferCNN/reduce_to_aseg.py index 3175fc904..72f7e1383 100644 --- a/FastSurferCNN/reduce_to_aseg.py +++ b/FastSurferCNN/reduce_to_aseg.py @@ -25,14 +25,13 @@ from skimage.filters import gaussian from skimage.measure import label -from FastSurferCNN.utils import logging +from FastSurferCNN.utils import logging, nibabelHeader, nibabelImage, ShapeType, AffineMatrix4x4 from FastSurferCNN.utils.brainvolstats import mask_in_array from FastSurferCNN.utils.logging import setup_logging from FastSurferCNN.utils.parallel import thread_executor _T = TypeVar("_T", bound=np.number) _TDType = np.dtype[_T] -_TShape = TypeVar("_TShape", bound=tuple[int, ...]) LOGGER = logging.getLogger(__name__) @@ -104,7 +103,7 @@ def options_parse(): return options -def reduce_to_aseg(data_inseg: np.ndarray[_TShape, _TDType]) -> np.ndarray[_TShape, _TDType]: +def reduce_to_aseg(data_inseg: np.ndarray[ShapeType, _TDType]) -> np.ndarray[ShapeType, _TDType]: """ Reduce the input segmentation to a simpler segmentation (for all data orientations, LIA/etc). @@ -127,7 +126,8 @@ def reduce_to_aseg(data_inseg: np.ndarray[_TShape, _TDType]) -> np.ndarray[_TSha return data_inseg -def create_mask(aseg_data: np.ndarray[_TShape, _TDType], dnum: int, enum: int) -> np.ndarray[_TShape, np.dtype[bool]]: +def create_mask(aseg_data: np.ndarray[ShapeType, _TDType], dnum: int, enum: int) \ + -> np.ndarray[ShapeType, np.dtype[np.bool_]]: """ Create dilated mask (works for all data orientations, LIA/etc). @@ -180,7 +180,7 @@ def create_mask(aseg_data: np.ndarray[_TShape, _TDType], dnum: int, enum: int) - return datab.astype(np.uint8) -def flip_wm_islands(aseg_data: np.ndarray[_TShape, _TDType]) -> np.ndarray[_TShape, _TDType]: +def flip_wm_islands(aseg_data: np.ndarray[ShapeType, _TDType]) -> np.ndarray[ShapeType, _TDType]: """ Flip labels of disconnected white matter islands to the other hemisphere (works for all data orientations, LIA/etc). @@ -204,7 +204,7 @@ def flip_wm_islands(aseg_data: np.ndarray[_TShape, _TDType]) -> np.ndarray[_TSha rh_wm = 41 rh_gm = 42 - def _islands(data: np.ndarray[_TShape, _TDType], _label: int) -> np.ndarray[_TShape, np.dtype[bool]]: + def _islands(data: np.ndarray[ShapeType, _TDType], _label: int) -> np.ndarray[ShapeType, np.dtype[np.bool_]]: # for lh get largest component and islands mask = data == _label labels = label(mask, background=0) @@ -234,11 +234,11 @@ def _islands(data: np.ndarray[_TShape, _TDType], _label: int) -> np.ndarray[_TSh def create_mask_and_save( - seg: np.ndarray[_TShape, np.dtype], - seg_affine: np.ndarray[tuple[int, int], np.dtype[float]], - seg_header: nib.analyze.SpatialHeader, + seg: np.ndarray[ShapeType, np.dtype], + seg_affine: AffineMatrix4x4, + seg_header: nibabelHeader, filename: Path | None = None, -) -> np.ndarray[_TShape, np.dtype[np.uint8]]: +) -> np.ndarray[ShapeType, np.dtype[np.uint8]]: """Convenience function for brainmask generation plus saving.""" mask_data = create_mask(seg, 5, 4) if filename is not None: @@ -249,11 +249,11 @@ def create_mask_and_save( def reduce_to_aseg_and_save( - seg: np.ndarray[_TShape, np.dtype], - seg_affine: np.ndarray[tuple[int, int], np.dtype[float]], - seg_header: nib.analyze.SpatialHeader, + seg: np.ndarray[ShapeType, np.dtype], + seg_affine: AffineMatrix4x4, + seg_header: nibabelHeader, filename: Path | None = None, -) -> np.ndarray[_TShape, np.dtype[np.uint8]]: +) -> np.ndarray[ShapeType, np.dtype[np.uint8]]: """Convenience function for reduce_to_aseg plus saving.""" _data = reduce_to_aseg(seg) @@ -271,7 +271,7 @@ def reduce_to_aseg_and_save( setup_logging() LOGGER.info(f"Reading in aparc+aseg: {options.input_seg} ...") - inseg = cast(nib.analyze.SpatialImage, nib.load(options.input_seg)) + inseg = cast(nibabelImage, nib.load(options.input_seg)) inseg_data = np.asanyarray(inseg.dataobj) inseg_header = inseg.header inseg_affine = inseg.affine diff --git a/FastSurferCNN/run_prediction.py b/FastSurferCNN/run_prediction.py index e73123e64..53bb5e1d2 100644 --- a/FastSurferCNN/run_prediction.py +++ b/FastSurferCNN/run_prediction.py @@ -31,7 +31,6 @@ from pathlib import Path from typing import Any, Literal -import nibabel as nib import numpy as np import torch import yacs.config @@ -42,7 +41,7 @@ from FastSurferCNN.data_loader.conform import conform, is_conform, orientation_to_ornts, to_target_orientation from FastSurferCNN.inference import Inference from FastSurferCNN.quick_qc import check_volume -from FastSurferCNN.utils import PLANES, Plane, logging, parser_defaults +from FastSurferCNN.utils import PLANES, Plane, logging, nibabelImage, parser_defaults from FastSurferCNN.utils.arg_types import OrientationType, VoxSizeOption from FastSurferCNN.utils.arg_types import vox_size as _vox_size from FastSurferCNN.utils.checkpoint import get_checkpoints, load_checkpoint_config_defaults @@ -294,7 +293,7 @@ def __conform_kwargs(self, **kwargs) -> dict[str, Any]: def conform_and_save_orig( self, subject: SubjectDirectory, - ) -> tuple[nib.analyze.SpatialImage, np.ndarray]: + ) -> tuple[nibabelImage, np.ndarray]: """ Conform and saves original image. @@ -305,8 +304,10 @@ def conform_and_save_orig( Returns ------- - tuple[nib.analyze.SpatialImage, np.ndarray] - Conformed image. + the_image : nibabelImage + The SpatialImage object from nibabel of the conformed image (including updated affine). + the_data : np.ndarray + The data of the conformed image. """ orig, orig_data = du.load_image(subject.orig_name, "orig image") LOGGER.info(f"Successfully loaded image from {subject.orig_name}.") @@ -405,7 +406,7 @@ def save_img( self, save_as: str | Path, data: np.ndarray | torch.Tensor, - orig: nib.analyze.SpatialImage, + orig: nibabelImage, dtype: type | None = None, ) -> None: """ @@ -417,7 +418,7 @@ def save_img( Filename to give the image. data : np.ndarray, torch.Tensor Image data. - orig : nib.analyze.SpatialImage + orig : nibabelImage Original Image. dtype : type, optional Data type to use for saving the image. If None, the original data type is used. @@ -441,7 +442,7 @@ def async_save_img( self, save_as: str | Path, data: np.ndarray | torch.Tensor, - orig: nib.analyze.SpatialImage, + orig: nibabelImage, dtype: type | None = None, ) -> Future[None]: """ @@ -454,7 +455,7 @@ def async_save_img( Filename to give the image. data : np.ndarray, torch.Tensor Image data. - orig : nib.analyze.SpatialImage + orig : nibabelImage Original Image. dtype : type, optional Data type to use for saving the image. If None, the original data type is used. @@ -491,7 +492,7 @@ def get_num_classes(self) -> int: def pipeline_conform_and_save_orig( self, subjects: SubjectList, - ) -> Iterator[tuple[SubjectDirectory, tuple[nib.analyze.SpatialImage, np.ndarray]]]: + ) -> Iterator[tuple[SubjectDirectory, tuple[nibabelImage, np.ndarray]]]: """ Pipeline for conforming and saving original images asynchronously. @@ -502,8 +503,15 @@ def pipeline_conform_and_save_orig( Yields ------ - tuple[SubjectDirectory, tuple[nib.analyze.SpatialImage, np.ndarray]] - Subject directory and a tuple with the image and its data. + subject_dir : SubjectDirectory + The SubjectDirectory object, that helps manage file names. + image_and_data : tuple of nibabelImage and np.ndarray + The tuple with the image and its data. + + See Also + -------- + RunModelOnData.conform_and_safe_orig + For more detailed description of `image_and_data`. """ if not self._async_io: # do not pipeline, direct iteration and function call diff --git a/FastSurferCNN/segstats.py b/FastSurferCNN/segstats.py index a67f3cd08..2c05b4215 100644 --- a/FastSurferCNN/segstats.py +++ b/FastSurferCNN/segstats.py @@ -19,38 +19,29 @@ import argparse import logging from collections.abc import Callable, Container, Iterable, Iterator, Sequence, Sized -from concurrent.futures import Executor, ThreadPoolExecutor +from concurrent.futures import Executor from functools import partial, reduce from itertools import product from numbers import Number from pathlib import Path -from typing import ( - IO, - Any, - Literal, - TypedDict, - TypeVar, - cast, - overload, -) - -import nibabel as nib +from typing import IO, Any, Literal, TypedDict, TypeVar, cast, overload + import numpy as np import pandas as pd from numpy import typing as npt +from FastSurferCNN.utils import nibabelImage from FastSurferCNN.utils.arg_types import float_gt_zero_and_le_one as robust_threshold from FastSurferCNN.utils.arg_types import int_ge_zero as id_type from FastSurferCNN.utils.arg_types import int_gt_zero as patch_size_type from FastSurferCNN.utils.brainvolstats import Manager, MeasureTuple, read_measure_file -from FastSurferCNN.utils.parallel import get_num_threads +from FastSurferCNN.utils.parallel import get_num_threads, set_num_threads, thread_executor from FastSurferCNN.utils.parser_defaults import add_arguments # Constants -USAGE = ("python segstats.py (-norm|-pv) -i " - "-o [optional arguments] [{measures,mri_segstats} ...]") -DESCRIPTION = ("Script to calculate partial volumes and other segmentation statistics " - "of a segmentation file.") +USAGE = ("python segstats.py (-norm|-pv) -i -o [optional arguments] " + "[{measures,mri_segstats} ...]") +DESCRIPTION = "Script to calculate partial volumes and other segmentation statistics of a segmentation file." VERSION = "1.1" HELPTEXT = f""" Dependencies: @@ -73,8 +64,7 @@ Revision: {VERSION} """ FILTER_SIZES = (3, 15) -COLUMNS = ["Index", "SegId", "NVoxels", "Volume_mm3", "StructName", "Mean", "StdDev", - "Min", "Max", "Range"] +COLUMNS = ["Index", "SegId", "NVoxels", "Volume_mm3", "StructName", "Mean", "StdDev", "Min", "Max", "Range"] # Type definitions _NumberType = TypeVar("_NumberType", bound=Number) @@ -88,6 +78,7 @@ float | None, float | None, float, npt.NDArray[bool]] SubparserCallback = type[argparse.ArgumentParser.add_subparsers] +DO_NOT_SAVE_FILE = Path("do not save the file") class _RequiredPVStats(TypedDict): SegId: int @@ -105,7 +96,7 @@ class _OptionalPVStats(TypedDict, total=False): class PVStats(_RequiredPVStats, _OptionalPVStats): - """Dictionary of volume statistics for partial volume evaluation and global stats""" + """Dictionary of volume statistics for partial volume evaluation and global stats.""" pass @@ -197,12 +188,17 @@ def make_arguments(helpformatter: bool = False) -> argparse.ArgumentParser: """ import sys if helpformatter: + # Help Formatter for command line kwargs = { "epilog": HELPTEXT.replace("\n", "
"), "formatter_class": HelpFormatter, } else: - kwargs = {"epilog": HELPTEXT} + # Help Formatter for documentation + kwargs = { + "epilog": HELPTEXT, + # "formatter_class": DocHelpFormatter, + } parser = argparse.ArgumentParser( usage=USAGE, description=DESCRIPTION, @@ -215,18 +211,16 @@ def make_arguments(helpformatter: bool = False) -> argparse.ArgumentParser: "-pv", type=Path, dest="pvfile", - help="Path to image used to compute the partial volume effects (default: the " - "file passed as normfile). This file is required, either directly or " - "indirectly via normfile.", + help="Path to image used to compute the partial volume effects (default: the file passed as normfile). This " + "file is required, either directly or indirectly via normfile.", ) parser.add_argument( "-norm", "--normfile", type=Path, dest="normfile", - help="Path to biasfield-corrected image (the same image space as " - "segmentation). This file is used to calculate intensity values. Also, if " - "no pvfile is defined, it is used as pvfile. One of normfile or pvfile is " + help="Path to biasfield-corrected image (the same image space as segmentation). This file is used to calculate " + "intensity values. Also, if no pvfile is defined, it is used as pvfile. One of normfile or pvfile is " "required.", ) parser.add_argument( @@ -251,15 +245,13 @@ def make_arguments(helpformatter: bool = False) -> argparse.ArgumentParser: type=id_type, nargs="*", default=[], - help="List of segmentation ids (integers) to exclude in analysis, " - "e.g. `--excludeid 0 1 10` (default: None).", + help="List of segmentation ids (integers) to exclude in analysis, e.g. `--excludeid 0 1 10` (default: None).", ) parser.add_argument( "--ids", type=id_type, nargs="*", - help="List of exclusive segmentation ids (integers) to use " - "(default: all ids in --lut or all ids in image).", + help="List of exclusive segmentation ids (integers) to use (default: all ids in --lut or all ids in image).", ) parser.add_argument( "--merged_label", @@ -268,20 +260,17 @@ def make_arguments(helpformatter: bool = False) -> argparse.ArgumentParser: dest="merged_labels", default=[], action="append", - help="Add a 'virtual' label (first value) that is the combination of all " - "following values, e.g. `--merged_label 100 3 4 8` will compute the " - "statistics for label 100 by aggregating labels 3, 4 and 8.", + help="Add a 'virtual' label (first value) that is the combination of all following values, e.g. " + "`--merged_label 100 3 4 8` will compute the statistics for label 100 by aggregating labels 3, 4 and 8.", ) parser.add_argument( "--robust", type=robust_threshold, dest="robust", default=None, - help="Whether to calculate robust segmentation metrics. This parameter " - "expects the fraction of values to keep, e.g. `--robust 0.95` will " - "ignore the 2.5%% smallest and the 2.5%% largest values in the " - "segmentation when calculating the statistics (default: no robust " - "statistics == `--robust 1.0`).", + help="Whether to calculate robust segmentation metrics. This parameter expects the fraction of values to keep, " + "e.g. `--robust 0.95` will ignore the 2.5%% smallest and the 2.5%% largest values in the segmentation " + "when calculating the statistics (default: no robust statistics == `--robust 1.0`).", ) parser.add_argument( "--measure_only", @@ -298,9 +287,8 @@ def make_arguments(helpformatter: bool = False) -> argparse.ArgumentParser: "--threads", dest="threads", default=get_num_threads(), - type=int, - help=f"Number of threads to use (defaults to number of hardware threads: " - f"{get_num_threads()})", + type=set_num_threads, + help=f"Number of threads to use (defaults to number of hardware threads: {get_num_threads()})", ) advanced.add_argument( "--patch_size", @@ -313,8 +301,7 @@ def make_arguments(helpformatter: bool = False) -> argparse.ArgumentParser: "--empty", action="store_true", dest="empty", - help="Keep ids for the table that do not exist in the segmentation " - "(default: drop).", + help="Keep ids for the table that do not exist in the segmentation (default: drop).", ) add_arguments(advanced, ["device", "sid", "sd"]) advanced.add_argument( @@ -329,57 +316,53 @@ def make_arguments(helpformatter: bool = False) -> argparse.ArgumentParser: action="store_true", dest="legacy_freesurfer", help="Reproduce FreeSurfer mri_segstats numbers (default: off). \n" - "Please note, that exact agreement of numbers cannot be guaranteed, " - "because the condition number of FreeSurfers algorithm (mri_segstats) " - "combined with the fact that mri_segstats uses 'float' to measure the " - "partial volume corrected volume. This yields differences of more than " - "60mm3 or 0.1%% in large structures. This uniquely impacts highres images " - "with more voxels (on the boundary) and smaller voxel sizes (volume per " - "voxel).", + "Please note, that exact agreement of numbers cannot be guaranteed, because the condition number of " + "FreeSurfers algorithm (mri_segstats) combined with the fact that mri_segstats uses 'float' to measure " + "the partial volume corrected volume. This yields differences of more than 60mm3 or 0.1%% in large " + "structures. This uniquely impacts highres images with more voxels (on the boundary) and smaller voxel " + "sizes (volume per voxel).", ) # Additional info: - # Changing the data type in mri_segstats to double can reduce this difference to - # nearly zero. + # Changing the data type in mri_segstats to double can reduce this difference to nearly zero. # mri_segstats has two operations affecting a bad condition number: # 1. pv = (val - mean_nbr) / (mean_label - mean_nbr) # 2. volume += vox_vol * pv - # This is further affected by the small vox_vol (volume per voxel) of highres - # images (0.7iso -> 0.343) - # Their effects stack and can result in differences of more than 60mm3 or 0.1% in - # a comparison between double and single-precision evaluations. + # This is further affected by the small vox_vol (volume per voxel) of highres images (0.7iso -> 0.343) + # Their effects stack and can result in differences of more than 60mm3 or 0.1% in a comparison between double and + # single-precision evaluations. advanced.add_argument( "--mixing_coeff", type=Path, dest="mix_coeff", - default="", + default=DO_NOT_SAVE_FILE, help="Save the mixing coefficients (default: off).", ) advanced.add_argument( "--alternate_labels", type=Path, dest="nbr", - default="", + default=DO_NOT_SAVE_FILE, help="Save the alternate labels (default: off).", ) advanced.add_argument( "--alternate_mixing_coeff", type=Path, dest="nbr_mix_coeff", - default="", + default=DO_NOT_SAVE_FILE, help="Save mixing coefficients of alternate labels (default: off).", ) advanced.add_argument( "--seg_means", type=Path, dest="seg_means", - default="", + default=DO_NOT_SAVE_FILE, help="Save means of segmentation labels (default: off).", ) advanced.add_argument( "--alternate_means", type=Path, dest="nbr_means", - default="", + default=DO_NOT_SAVE_FILE, help="Save means of alternate labels (default: off).", ) advanced.add_argument( @@ -387,8 +370,8 @@ def make_arguments(helpformatter: bool = False) -> argparse.ArgumentParser: type=id_type, dest="volume_precision", default=3, - help="Number of digits after dot in summary stats file (default: 3). Use 1 for " - "maximum FreeSurfer compatibility).", + help="Number of digits after dot in summary stats file (default: 3). Use 1 for maximum FreeSurfer " + "compatibility).", ) advanced.add_argument( "--norm_name", @@ -439,8 +422,7 @@ def __add_computed_measure(x: str) -> tuple[bool, str]: default=[], dest="measures", help="Additional Measures to compute based on imported/computed measures:
" - "Cortex, CerebralWhiteMatter, SubCortGray, TotalGray, " - "BrainSegVol-to-eTIV, MaskVol-to-eTIV, SurfaceHoles, " + "Cortex, CerebralWhiteMatter, SubCortGray, TotalGray, BrainSegVol-to-eTIV, MaskVol-to-eTIV, SurfaceHoles, " "EstimatedTotalIntraCranialVol", ) @@ -455,39 +437,33 @@ def __add_imported_measure(x: str) -> tuple[bool, str]: dest="measures", help="Additional Measures to import from the measurefile.
" "Example measures ('all' to import all measures in the measurefile):
" - "BrainSeg, BrainSegNotVent, SupraTentorial, SupraTentorialNotVent, " - "SubCortGray, lhCortex, rhCortex, Cortex, TotalGray, " - "lhCerebralWhiteMatter, rhCerebralWhiteMatter, CerebralWhiteMatter, Mask, " - "SupraTentorialNotVentVox, BrainSegNotVentSurf, VentricleChoroidVol, " - "BrainSegVol-to-eTIV, MaskVol-to-eTIV, lhSurfaceHoles, rhSurfaceHoles, " - "SurfaceHoles, EstimatedTotalIntraCranialVol
" - "Note, 'all' will always be overwritten by any explicitly mentioned " - "measures.", + "BrainSeg, BrainSegNotVent, SupraTentorial, SupraTentorialNotVent, SubCortGray, lhCortex, rhCortex, " + "Cortex, TotalGray, lhCerebralWhiteMatter, rhCerebralWhiteMatter, CerebralWhiteMatter, Mask, " + "SupraTentorialNotVentVox, BrainSegNotVentSurf, VentricleChoroidVol, BrainSegVol-to-eTIV, " + "MaskVol-to-eTIV, lhSurfaceHoles, rhSurfaceHoles, SurfaceHoles, EstimatedTotalIntraCranialVol
" + "Note, 'all' will always be overwritten by any explicitly mentioned measures.", ) measure_parser.add_argument( "--file", type=Path, dest="measurefile", default="brainvol.stats", - help="Default file to read measures (--import ...) from. If the path is " - "relative, it is interpreted as relative to subjects_dir/subject_id from" - "--sd and --subject_id.", + help="Default file to read measures (--import ...) from. If the path is relative, it is interpreted as " + "relative to subjects_dir/subject_id from --sd and --subject_id.", ) measure_parser.add_argument( "--from_seg", type=Path, dest="aseg_replace", default=None, - help="Replace the default segfile to compute measures from by -i/--segfile. " - "This will default to 'mri/aseg.mgz' for --legacy_freesurfer and to the " - "value of -i/--segfile otherwise." + help="Replace the default segfile to compute measures from by -i/--segfile. This will default to " + "'mri/aseg.mgz' for --legacy_freesurfer and to the value of -i/--segfile otherwise." ) def add_two_help_messages(parser: argparse.ArgumentParser) -> None: """ - Adds separate help flags -h and --help to the parser for simple and detailed help. - Both trigger the help action. + Adds separate help flags -h and --help to the parser for simple and detailed help. Both trigger the help action. Parameters ---------- @@ -514,8 +490,8 @@ def _check_arg_path( require_exist: bool = True, ) -> Path: """ - Check an argument that is supposed to be a Path object and finding the absolute - path, which can be derived from the subject_dir. + Check an argument that is supposed to be a Path object and finding the absolute path, which can be derived from the + subject_dir. Parameters ---------- @@ -523,11 +499,10 @@ def _check_arg_path( The arguments object. __attr: str The name of the attribute in the Namespace object. - allow_subject_dir : bool, optional - Whether relative paths are supposed to be understood with respect to - subjects_dir / subject_id (default: True). - require_exist : bool, optional - Raise a ValueError, if the indicated file does not exist (default: True). + allow_subject_dir : bool, default=True + Whether relative paths are supposed to be understood with respect to subjects_dir / subject_id. + require_exist : bool, default=True + Raise a ValueError, if the indicated file does not exist. Returns ------- @@ -537,8 +512,8 @@ def _check_arg_path( Raises ------ ValueError - If attribute does not exist, is not a Path (or convertible to a Path), or if - the file does not exist, but `require_exist` is True. + If attribute does not exist, is not a Path (or convertible to a Path), or if the file does not exist, but + `require_exist` is True. """ if (_attr_val := getattr(__args, __attr), None) is None: raise ValueError(f"No {__attr} passed.") @@ -575,8 +550,8 @@ def _check_arg_defined(attr: str, /, args: argparse.Namespace) -> bool: def check_shape_affine( - img1: "nib.analyze.SpatialImage", - img2: "nib.analyze.SpatialImage", + img1: nibabelImage, + img2: nibabelImage, name1: str, name2: str, ) -> None: @@ -781,11 +756,7 @@ def main(args: argparse.Namespace) -> Literal[0] | str: except ValueError as e: return e.args[0] - threads = getattr(args, "threads", 0) - if threads <= 0: - threads = get_num_threads() - - compute_threads = ThreadPoolExecutor(threads) + compute_threads = thread_executor() # the manager object supports preloading of files (see below) for io parallelization # and calculates the measure @@ -818,9 +789,7 @@ def main(args: argparse.Namespace) -> Literal[0] | str: pv_img, pv_data = _pv if not empty(pvfile_preproc := getattr(args, "pvfile_preproc", None)): - pv_preproc_future = compute_threads.submit( - preproc_image, pvfile_preproc, pv_data, - ) + pv_preproc_future = compute_threads.submit(preproc_image, pvfile_preproc, pv_data) check_shape_affine(seg, pv_img, "segmentation", "pv_guide") if normfile is not None: @@ -837,16 +806,12 @@ def main(args: argparse.Namespace) -> Literal[0] | str: lut = read_lut(lut_file) # manager.lut = lut except FileNotFoundError: - return ( - f"Could not find the ColorLUT in {lut_file}, make sure the --lut " - f"argument is valid." - ) + return f"Could not find the ColorLUT in {lut_file}, make sure the --lut argument is valid." except Exception as exception: return exception.args[0] if measure_only: - # in this mode, we do not output a data table anyways, so no need to compute - # all these PV values. + # in this mode, we do not output a data table anyways, so no need to compute all these PV values. labels, exclude_id = np.zeros((0,), dtype=int), [] else: try: @@ -889,8 +854,8 @@ def main(args: argparse.Namespace) -> Literal[0] | str: manager.compute_non_derived_pv(compute_threads) names = ["nbr", "nbr_means", "seg_means", "mix_coeff", "nbr_mix_coeff"] - save_maps_paths = (getattr(args, n, "") for n in names) - save_maps = any(bool(path) and path != Path() for path in save_maps_paths) + save_maps_paths = (getattr(args, n, DO_NOT_SAVE_FILE) for n in names) + save_maps = any(bool(path) and path != DO_NOT_SAVE_FILE and path != Path() for path in save_maps_paths) save_maps = save_maps and not measure_only if needs_pv_calc: @@ -918,7 +883,7 @@ def main(args: argparse.Namespace) -> Literal[0] | str: return e.args[0] print(f"Brain volume stats written to {segstatsfile}.") duration = (perf_counter_ns() - start) / 1e9 - print(f"Calculation took {duration:.2f} seconds using up to {threads} threads.") + print(f"Calculation took {duration:.2f} seconds using up to {get_num_threads()} threads.") return 0 _io_futures = [] @@ -926,7 +891,7 @@ def main(args: argparse.Namespace) -> Literal[0] | str: table, maps = out dtypes = [np.int16] + [np.float32] * 4 for name, dtype in zip(names, dtypes, strict=False): - if not bool(file := getattr(args, name, "")) or file == Path(): + if not bool(file := getattr(args, name, DO_NOT_SAVE_FILE)) or file == Path() or file == DO_NOT_SAVE_FILE: # skip "fullview"-files that are not defined continue print(f"Saving {name} to {file}...") @@ -976,7 +941,7 @@ def main(args: argparse.Namespace) -> Literal[0] | str: print(f"Partial volume stats for {dataframe.shape[0]} labels written to " f"{segstatsfile}.") duration = (perf_counter_ns() - start) / 1e9 - print(f"Calculation took {duration:.2f} seconds using up to {threads} threads.") + print(f"Calculation took {duration:.2f} seconds using up to {get_num_threads()} threads.") for _io_fut in _io_futures: if (e := _io_fut.exception()) is not None: @@ -1007,8 +972,7 @@ def infer_merged_labels( Returns ------- all_merged_labels : dict[int, Sequence[int]] - The dictionary of all merged labels (via :class:`PVMeasures` as well as - `merged_labels`). + The dictionary of all merged labels (via :class:`PVMeasures` as well as `merged_labels`). """ _merged_labels = {} if not empty(merged_labels): @@ -1747,7 +1711,7 @@ def pv_calc( eps: float = 1e-6, robust_percentage: float | None = None, merged_labels: VirtualLabel | None = None, - threads: int | Executor = -1, + threads: Executor | None = None, return_maps: bool = False, legacy_freesurfer: bool = False, ) -> list[PVStats] | tuple[list[PVStats], dict[str, np.ndarray]]: @@ -1774,14 +1738,13 @@ def pv_calc( Fraction for robust calculation of statistics. merged_labels : VirtualLabel, optional Defines labels to compute statistics for that are. - threads : int, concurrent.futures.Executor, default=-1 - Number of parallel threads to use in calculation, alternatively an executor - object. + threads : concurrent.futures.Executor, optional + Number of parallel threads to use in calculation, alternatively an executor object. + int deprecated: uses FastSurfer.utils.parallel.set_num_threads. return_maps : bool, default=False Returns a dictionary containing the computed maps. legacy_freesurfer : bool, default=False - Whether to use a freesurfer legacy compatibility mode to exactly replicate - freesurfer. + Whether to use a freesurfer legacy compatibility mode to exactly replicate freesurfer. Returns ------- @@ -1859,25 +1822,15 @@ def pv_calc( robust_percentage=robust_percentage, ) - if threads == 0: - raise ValueError("Zero is not a valid number of threads.") - elif isinstance(threads, int) and threads > 0: - nthreads = threads - elif isinstance(threads, Executor | int): - nthreads: int = get_num_threads() - else: - raise TypeError("threads must be int or concurrent.futures.Executor object.") - executor = ThreadPoolExecutor(nthreads) if isinstance(threads, int) else threads - map_kwargs = {"chunksize": 1 if nthreads < 0 else ceil(len(labels) / nthreads)} + executor = threads if isinstance(threads, Executor) else threads + map_kwargs = {"chunksize": 1 if get_num_threads() < 0 else ceil(len(labels) / get_num_threads())} global_stats_future = executor.map(global_stats_filled, all_labels, **map_kwargs) if return_maps: from concurrent.futures import ProcessPoolExecutor if isinstance(executor, ProcessPoolExecutor): - raise NotImplementedError( - "The ProcessPoolExecutor is not compatible with return_maps=True!" - ) + raise NotImplementedError("The ProcessPoolExecutor is not compatible with return_maps=True!") full_nbr_label = np.zeros(seg.shape, dtype=seg.dtype) full_nbr_mean = np.zeros(pv_guide.shape, dtype=float) full_seg_mean = np.zeros(pv_guide.shape, dtype=float) @@ -1906,7 +1859,7 @@ def pv_calc( patch_iters = [range(slc.start, slc.stop, patch_size) for slc in global_crop] # 4 chunks per core num_valid_labels = len(voxel_counts) - map_kwargs["chunksize"] = np.ceil(num_valid_labels / nthreads / 4).item() + map_kwargs["chunksize"] = np.ceil(num_valid_labels / get_num_threads() / 4).item() patch_filter_func = partial(patch_filter, mask=any_border, global_crop=global_crop, patch_size=patch_size) _patches = executor.map(patch_filter_func, product(*patch_iters), **map_kwargs) @@ -1985,8 +1938,8 @@ def calculate_merged_labels( eps: float = 1e-6, ) -> Iterator[PVStats]: """ - Calculate the statistics for meta-labels, i.e. labels based on other labels - (`merge_labels`). Add respective items to `table`. + Calculate the statistics for meta-labels, i.e. labels based on other labels via `merged_labels`. Also added as + respective items to `table`. Parameters ---------- diff --git a/FastSurferCNN/utils/__init__.py b/FastSurferCNN/utils/__init__.py index da1c5dcf9..237a69d47 100644 --- a/FastSurferCNN/utils/__init__.py +++ b/FastSurferCNN/utils/__init__.py @@ -13,28 +13,69 @@ # limitations under the License. __all__ = [ + "AffineMatrix4x4", "checkpoint", "common", + "Image2d", + "Image3d", + "Image4d", "load_config", "logging", "lr_scheduler", "mapper", + "Mask2d", + "Mask3d", + "Mask4d", "meters", "metrics", "misc", + "nibabelImage", + "nibabelHeader", "parser_defaults", - "threads", + "parallel", "Plane", "PlaneAxial", "PlaneCoronal", "PlaneSagittal", "PLANES", + "RotationMatrix3x3", + "ScalarType", + "Shape2d", + "Shape3d", + "Shape4d", + "ShapeType", + "Vector2d", + "Vector3d", ] -from typing import Literal, get_args +from typing import TYPE_CHECKING, Literal, TypedDict, TypeVar +if TYPE_CHECKING: + from nibabel.analyze import SpatialHeader as nibabelHeader + from nibabel.analyze import SpatialImage as nibabelImage +else: + class nibabelImage: ... + class nibabelHeader: ... + +from numpy import bool_, dtype, float_, ndarray, number + +AffineMatrix4x4 = ndarray[tuple[Literal[4], Literal[4]], dtype[float_]] PlaneAxial = Literal["axial"] PlaneCoronal = Literal["coronal"] PlaneSagittal = Literal["sagittal"] Plane = PlaneAxial | PlaneCoronal | PlaneSagittal PLANES: tuple[PlaneAxial, PlaneCoronal, PlaneSagittal] = ("axial", "coronal", "sagittal") +ScalarType = TypeVar("ScalarType", bound=number) +Vector2d = ndarray[tuple[Literal[2]], dtype[float_]] +Vector3d = ndarray[tuple[Literal[3]], dtype[float_]] +Shape2d = tuple[int, int] +Shape3d = tuple[int, int, int] +Shape4d = tuple[int, int, int, int] +ShapeType = TypeVar("ShapeType", bound=tuple[int, ...]) +Image2d = ndarray[Shape2d, dtype[ScalarType]] +Image3d = ndarray[Shape3d, dtype[ScalarType]] +Image4d = ndarray[Shape4d, dtype[ScalarType]] +Mask2d = ndarray[Shape2d, dtype[bool_]] +Mask3d = ndarray[Shape3d, dtype[bool_]] +Mask4d = ndarray[Shape4d, dtype[bool_]] +RotationMatrix3x3 = ndarray[tuple[Literal[3], Literal[3]], dtype[float_]] diff --git a/FastSurferCNN/utils/arg_types.py b/FastSurferCNN/utils/arg_types.py index a9c140fd7..d4caa4aa4 100644 --- a/FastSurferCNN/utils/arg_types.py +++ b/FastSurferCNN/utils/arg_types.py @@ -13,6 +13,7 @@ # limitations under the License. from itertools import permutations, product +from pathlib import Path from typing import Literal, cast import nibabel as nib @@ -279,3 +280,22 @@ def unquote_str(value: str) -> str: if val.startswith("'") and val.endswith("'"): return val[1:-1] return val + + +def path_or_none(a: str) -> Path | None: + """ + Convert a string into None, if it reads "none" or is empty, else convert to a Path. + + Parameters + ---------- + a : str + String to convert to Path. + + Returns + ------- + Path or None + Return None if `a` is empty or case-insensitive "none" else return a as Path. + """ + if a.lower() in ("none", ""): + return None + return Path(a) diff --git a/FastSurferCNN/utils/brainvolstats.py b/FastSurferCNN/utils/brainvolstats.py index a193ac265..e81a46f97 100644 --- a/FastSurferCNN/utils/brainvolstats.py +++ b/FastSurferCNN/utils/brainvolstats.py @@ -19,18 +19,19 @@ import numpy as np +from FastSurferCNN.utils import AffineMatrix4x4, nibabelImage, ShapeType from FastSurferCNN.utils.common import update_docstring +from FastSurferCNN.utils.lta import LTADict +from FastSurferCNN.utils.parallel import SerialExecutor, thread_executor if TYPE_CHECKING: import lapy - import nibabel as nib import pandas as pd from numpy import typing as npt - from CerebNet.datasets.utils import AffineMatrix4x4, LTADict MeasureTuple = tuple[str, str, int | float, str] -ImageTuple = tuple["nib.analyze.SpatialImage", np.ndarray[tuple[int, ...], np.dtype[np.number]]] +ImageTuple = tuple[nibabelImage, np.ndarray[tuple[int, ...], np.dtype[np.number]]] UnitString = Literal["unitless", "mm^3"] MeasureString = Union[str, "Measure"] AnyBufferType = Union[ @@ -48,13 +49,13 @@ "lapy.TriaMesh", "np.ndarray", "pd.DataFrame", + "LTADict", ]) DerivedAggOperation = Literal["sum", "ratio", "by_vox_vol"] AnyMeasure = Union["AbstractMeasure", str] PVMode = Literal["vox", "pv"] ClassesType = Sequence[int] -_TShape = TypeVar("_TShape", bound=tuple[int, ...]) -CondType = Callable[["np.ndarray[_TShape, np.dtype[np.number]]"], "np.ndarray[_TShape, np.dtype[bool]]"] +CondType = Callable[["np.ndarray[_TShape, np.dtype[np.number]]"], "np.ndarray[_TShape, np.dtype[np.bool_]]"] ClassesOrCondType = ClassesType | CondType MaskSign = Literal["abs", "pos", "neg"] @@ -179,7 +180,7 @@ def read_lta_transform_file(path: Path) -> "AffineMatrix4x4": matrix : AffineMatrix4x4 Matrix of shape (4, 4). """ - from CerebNet.datasets.utils import read_lta + from FastSurferCNN.utils.lta import read_lta return read_lta(path)["lta"][0, 0] @@ -241,11 +242,11 @@ def read_transform_file(path: Path) -> "AffineMatrix4x4": def mask_in_array( - arr: np.ndarray[_TShape, np.dtype[np.unsignedinteger]], + arr: np.ndarray[ShapeType, np.dtype[np.unsignedinteger]], items: "npt.ArrayLike", /, max_index: np.unsignedinteger | None = None, -) -> np.ndarray[_TShape, np.dtype[bool]]: +) -> np.ndarray[ShapeType, np.dtype[np.bool_]]: """ Efficient function to generate a mask of elements in `arr`, which are also in items. @@ -287,7 +288,7 @@ def mask_in_array( def mask_not_in_array( - arr: np.ndarray[_TShape, np.dtype[np.unsignedinteger]], + arr: np.ndarray[ShapeType, np.dtype[np.unsignedinteger]], items: "npt.ArrayLike", /, max_index: np.unsignedinteger | None = None, @@ -326,7 +327,7 @@ def mask_not_in_array( def __infer_check_max_index( - arr: np.ndarray[_TShape, np.dtype[np.unsignedinteger]], + arr: np.ndarray[ShapeType, np.dtype[np.unsignedinteger]], items: "npt.ArrayLike", max_index: int | None, ) -> int: @@ -342,9 +343,9 @@ def __infer_check_max_index( @update_docstring(left_classes=ASEG_LEFT_CLASSES, right_classes=ASEG_RIGHT_CLASSES) def hemi_masks_from_aseg( - arr: np.ndarray[_TShape, np.dtype[np.integer]], + arr: np.ndarray[ShapeType, np.dtype[np.integer]], window_size: int = 7, -) -> tuple[np.ndarray[_TShape, np.dtype[bool]], np.ndarray[_TShape, np.dtype[bool]]]: +) -> tuple[np.ndarray[ShapeType, np.dtype[np.bool_]], np.ndarray[ShapeType, np.dtype[np.bool_]]]: """ Determine for each voxel if it is more likely left hemisphere or right hemisphere. @@ -381,8 +382,8 @@ def hemi_masks_from_aseg( def __ness(classes): return uniform_filter(mask_in_array(arr, classes).astype(np.float32), size=window_size) - _leftness: np.ndarray[_TShape, np.dtype[np.float32]] - _rightness: np.ndarray[_TShape, np.dtype[np.float32]] + _leftness: np.ndarray[ShapeType, np.dtype[np.float32]] + _rightness: np.ndarray[ShapeType, np.dtype[np.float32]] _leftness, _rightness = _map(__ness, (ASEG_LEFT_CLASSES, ASEG_RIGHT_CLASSES)) @@ -984,7 +985,7 @@ def __init__( # self._erode: int = erode super().__init__(maskfile, self.mask, name, description, unit, read_file) - def mask(self, data: np.ndarray[_TShape, np.number]) -> np.ndarray[_TShape, np.dtype[bool]]: + def mask(self, data: np.ndarray[ShapeType, np.dtype[np.number]]) -> np.ndarray[ShapeType, np.dtype[np.bool_]]: """Generates a mask from data similar to mri_binarize + erosion.""" # if self._sign == "abs": # data = np.abs(data) @@ -1259,8 +1260,8 @@ def read_subject(self, subject_dir: Path) -> bool: Notes ----- - Might trigger a race condition if the function hook `read_subject_on_parents` - depends on this method finishing first, e.g. because of thread availability. + Might trigger a race condition if the function hook `read_subject_on_parents` depends on this method finishing + first, e.g. because of thread availability. """ if super().read_subject(subject_dir): return self.read_subject_on_parents(self._subject_dir) @@ -1454,7 +1455,7 @@ def __init__( legacy_freesurfer : bool, default=False FreeSurfer compatibility mode. """ - from concurrent.futures import Future, ThreadPoolExecutor + from concurrent.futures import Future from copy import deepcopy def _check_measures(x): @@ -1464,15 +1465,7 @@ def _check_measures(x): self._default_measures = deepcopy(self.__DEFAULT_MEASURES) if not isinstance(measures, Sequence) or any(map(_check_measures, measures)): raise ValueError("measures must be sequences of str.") - if executor is None: - self._executor = ThreadPoolExecutor(8) - elif isinstance(executor, ThreadPoolExecutor): - self._executor = executor - else: - raise TypeError( - "executor must be a futures.concurrent.ThreadPoolExecutor to ensure " - "proper multitask behavior." - ) + self._io_futures: list[Future] = [] self.__update_context: list[AbstractMeasure] = [] self._on_missing = on_missing @@ -1517,7 +1510,7 @@ def _check_measures(x): @property def executor(self) -> Executor: - return self._executor + return thread_executor() # @property # def lut(self) -> Optional["pd.DataFrame"]: @@ -1648,8 +1641,7 @@ def __getitem__(self, key: str) -> AbstractMeasure: def start_read_subject(self, subject_dir: Path) -> None: """ - Start the threads to read the subject in subject_dir, pairs with - `wait_read_subject`. + Start the threads to read the subject in subject_dir, pairs with `wait_read_subject`. Parameters ---------- @@ -1776,7 +1768,7 @@ def _read(measure: AbstractMeasure) -> bool: # calls read_subject on all measures, redundant io operations are # handled/skipped through Manager.make_read_hook and the internal # caching of files within the _cache attribute of Manager. - self._io_futures.append(self._executor.submit(_read, x)) + self._io_futures.append(thread_executor().submit(_read, x)) return True def extract_key_args(self, measure: str) -> tuple[str, list[str]]: @@ -1855,7 +1847,7 @@ def read_wrapper(file: Path, blocking: bool = True) -> T_BufferType | None: if blocking: out = read_func(file) else: - out = self._executor.submit(read_func, file) + out = thread_executor().submit(read_func, file) self._cache[file] = out if not blocking: return @@ -2128,7 +2120,8 @@ def default(self, key: str) -> AbstractMeasure: ) elif key in ("lhWM-hypointensities", "rhWM-hypointensities"): # lateralized counting of class 77 WM hypo intensities - def mask_77_lat(arr: np.ndarray[_TShape, np.dtype[np.integer]]) -> np.ndarray[_TShape, np.dtype[bool]]: + def mask_77_lat(arr: np.ndarray[ShapeType, np.dtype[np.integer]]) \ + -> np.ndarray[ShapeType, np.dtype[np.bool_]]: """ This function returns a lateralized mask of hypo-WM (class 77). @@ -2147,12 +2140,35 @@ def mask_77_lat(arr: np.ndarray[_TShape, np.dtype[np.integer]]) -> np.ndarray[_T f"Volume of {side} White matter hypointensities", "mm^3" ) + elif key in ("lhFornix", "rhFornix"): + # lateralized counting of class 192 Fornix + def mask_192_lat(arr: np.ndarray[ShapeType, np.dtype[np.integer]]) \ + -> np.ndarray[ShapeType, np.dtype[np.bool_]]: + """ + This function returns a lateralized mask of the Fornix (class 192). + + This is achieved by looking at surrounding labels and associating them + with left or right (this is not 100% robust when there is no clear + classes with left aseg labels present, but it is cheap to perform). + """ + mask = arr == 192 + side_index = {"Left": 0, "Right": 1}[side] + return np.logical_and(hemi_masks_from_aseg(arr)[side_index], mask) + + return VolumeMeasure( + self._seg_from_file, + mask_192_lat, + f"{hemi}Fornix", + f"Volume of the {side} Fornix", + "mm^3" + ) elif key in ("lhCerebralWhiteMatter", "rhCerebralWhiteMatter"): # SurfaceVolume # 9/10 => l/rCerebralWM parents = [ f"{hemi}WhiteMatterVol", f"{hemi}WM-hypointensities", + f"{hemi}Fornix", (0.5, "CorpusCallosumVol"), ] return DerivedMeasure( @@ -2390,19 +2406,10 @@ def compute_non_derived_pv( For each non-derived and non-PV measure, a future object that is associated with the call to the measure. """ - - def run(f: Callable[[], int | float]) -> Future[int | float]: - out = Future() - out.set_result(f()) - return out - - if isinstance(compute_threads, Executor): - run = compute_threads.submit + run = compute_threads.submit if isinstance(compute_threads, Executor) else SerialExecutor().submit invalid_types = (DerivedMeasure, PVMeasure) - self._compute_futures = [ - run(this) for this in self.values() if not isinstance(this, invalid_types) - ] + self._compute_futures = [run(this) for this in self.values() if not isinstance(this, invalid_types)] return self._compute_futures def needs_pv_calculation(self) -> bool: diff --git a/FastSurferCNN/utils/common.py b/FastSurferCNN/utils/common.py index 1aaae5822..da6ecf9fb 100644 --- a/FastSurferCNN/utils/common.py +++ b/FastSurferCNN/utils/common.py @@ -186,9 +186,8 @@ def handle_cuda_memory_exception(exception: BaseException) -> bool: if message.startswith("CUDA out of memory. "): LOGGER.critical("ERROR - INSUFFICIENT GPU MEMORY") LOGGER.info( - "The memory requirements exceeds the available GPU memory, try using a " - "smaller batch size (--batch_size ) and/or view aggregation on the " - "cpu (--viewagg_device 'cpu')." + "The memory requirements exceeds the available GPU memory, try using a smaller batch size " + "(--batch_size ) and/or view aggregation on the cpu (--viewagg_device 'cpu')." ) LOGGER.info( "Note: View Aggregation on the GPU is particularly memory-hungry at " @@ -725,9 +724,8 @@ def __init__( self._out_segfile = getattr(self, "_segfile_", None) if self._out_segfile is None: raise RuntimeError( - "The segmentation output file is not set, it should be either " - "'segfile' (which gets populated from args.segfile), or a keyword " - "argument to __init__, e.g. `SubjectList(args, subseg='subseg_param', " + "The segmentation output file is not set, it should be either 'segfile' (which gets populated from " + "args.segfile), or a keyword argument to `__init__`, e.g. `SubjectList(args, subseg='subseg_param', " "out_filename='subseg')`." ) @@ -735,9 +733,8 @@ def __init__( self._out_dir = getattr(args, "out_dir", None) or getattr(args, "in_dir", None) if self._out_dir in [None, ""] and not os.path.isabs(self._out_segfile): msg = ( - "Please specify, where the segmentation output should be stored by " - "either the {sd[flag]} flag (output subject directory, this can be " - "same as input directory) or an absolute path to the " + "Please specify, where the segmentation output should be stored by either the {sd[flag]} flag (output " + "subject directory, this can be same as input directory) or an absolute path to the " "{asegdkt_segfile[flag]} output segmentation volume." ) raise RuntimeError(msg.format(**self._flags)) diff --git a/FastSurferCNN/utils/lta.py b/FastSurferCNN/utils/lta.py new file mode 100755 index 000000000..47de9a7c7 --- /dev/null +++ b/FastSurferCNN/utils/lta.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Literal, TypedDict + +import numpy.typing as npt + +from FastSurferCNN.utils import AffineMatrix4x4 + + +class LTADict(TypedDict): + """ + Typed dictionary containing all the information from an LTA transform file. + + Attributes + ---------- + type: int + nxforms: int + mean: list[float] + sigma: float + lta: AffineMatrix4x4 + src_valid: int + src_filename: str + src_volume: list[int] + src_voxelsize: list[float] + src_xras: list[float] + src_yras: list[float] + src_zras: list[float] + src_cras: list[float] + dst_valid: int + dst_filename: str + dst_volume: list[int] + dst_voxelsize: list[float] + dst_xras: list[float] + dst_yras: list[float] + dst_zras: list[float] + dst_cras: list[float] + src: AffineMatrix4x4 + dst: AffineMatrix4x4 + """ + type: int + nxforms: int + mean: list[float] + sigma: float + lta: AffineMatrix4x4 + src_valid: int + src_filename: str + src_volume: list[int] + src_voxelsize: list[float] + src_xras: list[float] + src_yras: list[float] + src_zras: list[float] + src_cras: list[float] + dst_valid: int + dst_filename: str + dst_volume: list[int] + dst_voxelsize: list[float] + dst_xras: list[float] + dst_yras: list[float] + dst_zras: list[float] + dst_cras: list[float] + src: AffineMatrix4x4 + dst: AffineMatrix4x4 + + +# Collection of functions related to FreeSurfer's LTA (linear transform array) files: +def write_lta( + filename: Path | str, + affine: npt.ArrayLike, + src_fname: Path | str, + src_header: dict, + dst_fname: Path | str, + dst_header: dict +) -> None: + """ + Write linear transform array info to an .lta file. + + Parameters + ---------- + filename : Path, str + File to write on. + affine : npt.ArrayLike + Linear transform array to be saved. + src_fname : Path, str + Source filename. + src_header : Dict + Source header. + dst_fname : Path, str + Destination filename. + dst_header : Dict + Destination header. + + Raises + ------ + ValueError + Header format missing field (Source or Destination). + """ + import getpass + from datetime import datetime + + fields = ("dims", "delta", "Mdc", "Pxyz_c") + for field in fields: + if field not in src_header: + raise ValueError(f"write_lta Error: src_header format missing field: {field}") + if field not in dst_header: + raise ValueError(f"write_lta Error: dst_header format missing field: {field}") + + src_dims = str(src_header["dims"][0:3]) + src_vsize = str(src_header["delta"][0:3]) + src_v2r = src_header["Mdc"] + src_c = src_header["Pxyz_c"] + + dst_dims = str(dst_header["dims"][0:3]) + dst_vsize = str(dst_header["delta"][0:3]) + dst_v2r = dst_header["Mdc"] + dst_c = dst_header["Pxyz_c"] + + f = open(filename, "w") + f.write( + (f"# transform file {filename}\n" + f"# created by {getpass.getuser()} on {datetime.now().ctime()}\n\n" + "type = 1 # LINEAR_RAS_TO_RAS\n" + "nxforms = 1\n" + "mean = 0.0 0.0 0.0\n" + "sigma = 1.0\n" + "1 4 4\n" + f"{affine}\n" + "src volume info\n" + "valid = 1 # volume info valid\n" + f"filename = {src_fname}\n" + f"volume = {src_dims}\n" + f"voxelsize = {src_vsize}\n" + f"xras = {src_v2r[0, :]}\n" + f"yras = {src_v2r[1, :]}\n" + f"zras = {src_v2r[2, :]}\n" + f"cras = {src_c}\n" + "dst volume info\n" + "valid = 1 # volume info valid\n" + f"filename = {dst_fname}\n" + f"volume = {dst_dims}\n" + f"voxelsize = {dst_vsize}\n" + f"xras = {dst_v2r[0, :]}\n" + f"yras = {dst_v2r[1, :]}\n" + f"zras = {dst_v2r[2, :]}\n" + f"cras = {dst_c}\n").replace("[", "").replace("]", "") + ) + f.close() + + +def read_lta(file: Path | str) -> LTADict: + """Read the LTA info.""" + import re + from functools import partial + + import numpy as np + parameter_pattern = re.compile("^\\s*([^=]+)\\s*=\\s*([^#]*)\\s*(#.*)") + vol_info_pattern = re.compile("^(.*) volume info$") + shape_pattern = re.compile("^(\\s*\\d+)+$") + matrix_pattern = re.compile("^(-?\\d+\\.\\S+\\s+)+$") + + def _vector(_a: str, dtype: npt.DTypeLike = float, count: int = -1) -> npt.DTypeLike: + return np.fromstring(_a, dtype=dtype, count=count, sep=" ").tolist() + + parameters = { + "type": int, + "nxforms": int, + "mean": partial(_vector, dtype=float, count=3), + "sigma": float, + "subject": str, + "fscale": float, + } + vol_info_par = { + "valid": int, + "filename": str, + "volume": partial(_vector, dtype=int, count=3), + "voxelsize": partial(_vector, dtype=float, count=3), + **{f"{c}ras": partial(_vector, dtype=float) for c in "xyzc"} + } + + with open(file) as f: + lines = f.readlines() + + items = [] + shape_lines = [] + matrix_lines = [] + section = "" + for i, line in enumerate(lines): + if line.strip() == "": + continue + if hits := parameter_pattern.match(line): + name = hits.group(1) + if section and name in vol_info_par: + items.append((f"{section}_{name}", vol_info_par[name](hits.group(2)))) + elif name in parameters: + section = "" + items.append((name, parameters[name](hits.group(2)))) + else: + raise NotImplementedError(f"Unrecognized type string in lta-file " + f"{file}:{i+1}: '{name}'") + elif hits := vol_info_pattern.match(line): + section = hits.group(1) + # not a parameter line + elif shape_pattern.search(line): + shape_lines.append(np.fromstring(line, dtype=int, count=-1, sep=" ")) + elif matrix_pattern.search(line): + matrix_lines.append(np.fromstring(line, dtype=float, count=-1, sep=" ")) + + shape_lines = list(map(tuple, shape_lines)) + lta = dict(items) + if lta["nxforms"] != len(shape_lines): + raise OSError("Inconsistent lta format: nxforms inconsistent with shapes.") + if len(shape_lines) > 1 and np.any(np.not_equal([shape_lines[0]], shape_lines[1:])): + raise OSError(f"Inconsistent lta format: shapes inconsistent {shape_lines}") + lta_matrix = np.asarray(matrix_lines).reshape((-1,) + shape_lines[0].shape) + lta["lta"] = lta_matrix + return lta diff --git a/FastSurferCNN/utils/metrics.py b/FastSurferCNN/utils/metrics.py index 404aa6c77..34bdfba79 100644 --- a/FastSurferCNN/utils/metrics.py +++ b/FastSurferCNN/utils/metrics.py @@ -115,7 +115,7 @@ class DiceScore: A callable to update the accumulator. Method's signature is `(accumulator, output)`. For example, to compute arithmetic mean value, `op = lambda a, x: a + x`. output_transform : callable, optional - A callable that is used to transform the :class:`~ignite.engine.Engine`'s `process_function`'s output into the + A callable that is used to transform the :class:`~ignite.engine.Engine`\'s `process_function`\'s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. device : str or torch.device, optional diff --git a/FastSurferCNN/utils/parallel.py b/FastSurferCNN/utils/parallel.py index a442067f5..f501bf43e 100644 --- a/FastSurferCNN/utils/parallel.py +++ b/FastSurferCNN/utils/parallel.py @@ -27,6 +27,7 @@ "serial_executor", "set_num_threads", "SerialExecutor", + "shutdown_executors", "thread_executor", ] diff --git a/FastSurferCNN/utils/parser_defaults.py b/FastSurferCNN/utils/parser_defaults.py index 93d284ffb..9dd483ad4 100644 --- a/FastSurferCNN/utils/parser_defaults.py +++ b/FastSurferCNN/utils/parser_defaults.py @@ -256,8 +256,7 @@ class SubjectDirectoryConfig: "device": __arg( "--device", default="auto", - help="Select device to run inference on: cpu, or cuda (= Nvidia gpu) or specify a certain gpu (e.g. cuda:1), " - "Default: auto", + help="Select device to run inference on: cpu, or cuda (= Nvidia gpu) or specify a certain gpu (e.g. cuda:1)" ), "viewagg_device": __arg( "--viewagg_device", diff --git a/FastSurferCNN/utils/run_tools.py b/FastSurferCNN/utils/run_tools.py index fd3b85b72..e38056b31 100644 --- a/FastSurferCNN/utils/run_tools.py +++ b/FastSurferCNN/utils/run_tools.py @@ -151,7 +151,7 @@ def next_message(self, timeout: float) -> MessageBuffer: def finish(self, timeout: float = None) -> MessageBuffer: """ - `finish`'s behavior is similar to `subprocess.dry_run`. + `finish`\'s behavior is similar to `subprocess.dry_run`. `finish` waits `timeout` seconds, and forces termination after. By default, waits unlimited `timeout=None`. In either case, all messages in stdout and diff --git a/HypVINN/data_loader/dataset.py b/HypVINN/data_loader/dataset.py index 15db9f366..f301f4afa 100644 --- a/HypVINN/data_loader/dataset.py +++ b/HypVINN/data_loader/dataset.py @@ -125,8 +125,8 @@ def __init__( else: logger.info( f"For inference T1 block weight was set to: " - f"{self.weight_factor.numpy()[0]} and the T2 block was set to: " - f"{self.weight_factor.numpy()[1]}") + f"{self.weight_factor.cpu().numpy()[0]} and the T2 block was set to: " + f"{self.weight_factor.cpu().numpy()[1]}") def _standarized_img(self, orig_data: np.ndarray, orig_zoom: npt.NDArray[float], modality: np.ndarray) -> np.ndarray: diff --git a/README.md b/README.md index eda2acd25..8ff862d8f 100644 --- a/README.md +++ b/README.md @@ -24,16 +24,20 @@ Modules (all run by default): - the core, outputs anatomical segmentation and cortical parcellation and statistics of 95 classes, mimics FreeSurfer’s DKTatlas. - requires a T1w image ([notes on input images](#requirements-to-input-images)), supports high-res (up to 0.7mm, experimental beyond that). - performs bias-field correction and calculates volume statistics corrected for partial volume effects (skipped if `--no_biasfield` is passed). -2. `cereb:` [CerebNet](CerebNet/README.md) for cerebellum sub-segmentation (deactivate with `--no_cereb`) +2. `cc`: [CorpusCallosum](CorpusCallosum/README.md) for corpus callosum segmentation and shape analysis (deactivate with `--no_cc`) + - requires `asegdkt_segfile` (segmentation) and conformed mri (orig.mgz), outputs CC segmentation, thickness, and shape metrics. + - standardizes brain orientation based on AC/PC landmarks (orient_volume.lta). +3. `cereb:` [CerebNet](CerebNet/README.md) for cerebellum sub-segmentation (deactivate with `--no_cereb`) - requires `asegdkt_segfile`, outputs cerebellar sub-segmentation with detailed WM/GM delineation. - requires a T1w image ([notes on input images](#requirements-to-input-images)), which will be resampled to 1mm isotropic images (no native high-res support). - calculates volume statistics corrected for partial volume effects (skipped if `--no_biasfield` is passed). -3. `hypothal`: [HypVINN](HypVINN/README.md) for hypothalamus subsegmentation (deactivate with `--no_hypothal`) +4. `hypothal`: [HypVINN](HypVINN/README.md) for hypothalamus subsegmentation (deactivate with `--no_hypothal`) - outputs a hypothalamic subsegmentation including 3rd ventricle, c. mammilare, fornix and optic tracts. - a T1w image is highly recommended ([notes on input images](#requirements-to-input-images)), supports high-res (up to 0.7mm, but experimental beyond that). - allows the additional passing of a T2w image with `--t2 `, which will be registered to the T1w image (see `--reg_mode` option). - calculates volume statistics corrected for partial volume effects based on the T1w image (skipped if `--no_bias_field` is passed). + ### Surface reconstruction - approximately 60-90 minutes, `--surf_only` runs only [the surface part](recon_surf/README.md). - supports high-resolution images (up to 0.7mm, experimental beyond that). @@ -125,6 +129,8 @@ All the examples can be found here: [FASTSURFER_EXAMPLES](doc/overview/EXAMPLES. Modules output can be found here: [FastSurfer_Output_Files](doc/overview/OUTPUT_FILES.md) - [Segmentation module](doc/overview/OUTPUT_FILES.md#segmentation-module) - [Cerebnet module](doc/overview/OUTPUT_FILES.md#cerebnet-module) +- [HypVINN module](doc/overview/OUTPUT_FILES.md#hypvinn-module) +- [Corpus Callosum module](doc/overview/OUTPUT_FILES.md#corpus-callosum-module) - [Surface module](doc/overview/OUTPUT_FILES.md#surface-module) @@ -146,7 +152,7 @@ The default device is the GPU. The view-aggregation device can be switched to CP ## Expert usage Individual modules and the surface pipeline can be run independently of the full pipeline script documented in this documentation. -This is documented in READMEs in subfolders, for example: [whole brain segmentation only with FastSurferVINN](FastSurferCNN/README.md), [cerebellum sub-segmentation](CerebNet/README.md), [hypothalamic sub-segmentation](HypVINN/README.md) and [surface pipeline only (recon-surf)](recon_surf/README.md). +This is documented in READMEs in subfolders, for example: [whole brain segmentation only with FastSurferVINN](FastSurferCNN/README.md), [cerebellum sub-segmentation](CerebNet/README.md), [hypothalamic sub-segmentation](HypVINN/README.md), [corpus callosum analysis](CorpusCallosum/README.md) and [surface pipeline only (recon-surf)](recon_surf/README.md). Specifically, the segmentation modules feature options for optimized parallelization of batch processing. diff --git a/Tutorial/README.md b/Tutorial/README.md index 39989682f..200721eb4 100644 --- a/Tutorial/README.md +++ b/Tutorial/README.md @@ -58,7 +58,7 @@ It is normally recommended to run your set ups in separate virtual environments #### 2. Anaconda You can install anaconda via curl with the following command: ```bash -# The version of Anaconda may be different depending on when you are installing` +# The version of Anaconda may be different depending on when you are installing curl -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh sh Miniconda3-latest-Linux-x86_64.sh # and follow the prompts. The defaults are generally good. diff --git a/doc/api/CorpusCallosum.data.rst b/doc/api/CorpusCallosum.data.rst new file mode 100644 index 000000000..a89128e20 --- /dev/null +++ b/doc/api/CorpusCallosum.data.rst @@ -0,0 +1,11 @@ +CorpusCallosum.data +=================== + +.. currentmodule:: CorpusCallosum.data + +.. autosummary:: + :toctree: generated/ + + constants + fsaverage_cc_template + read_write diff --git a/doc/api/CorpusCallosum.localization.rst b/doc/api/CorpusCallosum.localization.rst new file mode 100644 index 000000000..9c6c3b400 --- /dev/null +++ b/doc/api/CorpusCallosum.localization.rst @@ -0,0 +1,9 @@ +CorpusCallosum.localization +============================= + +.. currentmodule:: CorpusCallosum.localization + +.. autosummary:: + :toctree: generated/ + + inference diff --git a/doc/api/CorpusCallosum.rst b/doc/api/CorpusCallosum.rst new file mode 100644 index 000000000..7d9152e5b --- /dev/null +++ b/doc/api/CorpusCallosum.rst @@ -0,0 +1,11 @@ +CorpusCallosum +============== + +.. currentmodule:: CorpusCallosum + +.. autosummary:: + :toctree: generated/ + + fastsurfer_cc + cc_visualization + paint_cc_into_pred diff --git a/doc/api/CorpusCallosum.segmentation.rst b/doc/api/CorpusCallosum.segmentation.rst new file mode 100644 index 000000000..0269688bf --- /dev/null +++ b/doc/api/CorpusCallosum.segmentation.rst @@ -0,0 +1,10 @@ +CorpusCallosum.segmentation +============================ + +.. currentmodule:: CorpusCallosum.segmentation + +.. autosummary:: + :toctree: generated/ + + inference + segmentation_postprocessing diff --git a/doc/api/CorpusCallosum.shape.rst b/doc/api/CorpusCallosum.shape.rst new file mode 100644 index 000000000..f4c059e3f --- /dev/null +++ b/doc/api/CorpusCallosum.shape.rst @@ -0,0 +1,15 @@ +CorpusCallosum.shape +==================== + +.. currentmodule:: CorpusCallosum.shape + +.. autosummary:: + :toctree: generated/ + + postprocessing + mesh + metrics + thickness + subsegment_contour + endpoint_heuristic + contour diff --git a/doc/api/CorpusCallosum.transforms.rst b/doc/api/CorpusCallosum.transforms.rst new file mode 100644 index 000000000..14756a92e --- /dev/null +++ b/doc/api/CorpusCallosum.transforms.rst @@ -0,0 +1,10 @@ +CorpusCallosum.transforms +=========================== + +.. currentmodule:: CorpusCallosum.transforms + +.. autosummary:: + :toctree: generated/ + + localization + segmentation diff --git a/doc/api/CorpusCallosum.utils.rst b/doc/api/CorpusCallosum.utils.rst new file mode 100644 index 000000000..33fe5e045 --- /dev/null +++ b/doc/api/CorpusCallosum.utils.rst @@ -0,0 +1,12 @@ +CorpusCallosum.utils +==================== + +.. currentmodule:: CorpusCallosum.utils + +.. autosummary:: + :toctree: generated/ + + checkpoint + mapping_helpers + types + visualization diff --git a/doc/api/FastSurferCNN.utils.rst b/doc/api/FastSurferCNN.utils.rst index 9b0b00f6e..b9fe7b984 100644 --- a/doc/api/FastSurferCNN.utils.rst +++ b/doc/api/FastSurferCNN.utils.rst @@ -13,6 +13,7 @@ FastSurferCNN.utils load_config logging lr_scheduler + lta mapper meters metrics diff --git a/doc/api/index.rst b/doc/api/index.rst index 546cdf4fa..fd606a8ba 100644 --- a/doc/api/index.rst +++ b/doc/api/index.rst @@ -16,6 +16,13 @@ FastSurfer API CerebNet.datasets.rst CerebNet.models.rst CerebNet.utils.rst + CorpusCallosum.rst + CorpusCallosum.data.rst + CorpusCallosum.localization.rst + CorpusCallosum.segmentation.rst + CorpusCallosum.shape.rst + CorpusCallosum.transforms.rst + CorpusCallosum.utils.rst HypVINN.rst HypVINN.dataloader.rst HypVINN.models.rst diff --git a/doc/api/recon_surf.rst b/doc/api/recon_surf.rst index 54403fb94..0387d24ed 100644 --- a/doc/api/recon_surf.rst +++ b/doc/api/recon_surf.rst @@ -11,17 +11,11 @@ recon_surf align_seg create_annotation fs_balabels - lta map_surf_label N4_bias_correct - paint_cc_into_pred rewrite_oriented_surface rewrite_mc_surface rotate_sphere sample_parc smooth_aparc spherically_project_wrapper - - - - diff --git a/doc/conf.py b/doc/conf.py index be1ed6383..0f4728b53 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -13,8 +13,7 @@ import os from pathlib import Path -# here i added the relative path because sphinx was not able -# to locate FastSurferCNN module directly for autosummary +# relative path so sphinx can locate the different modules directly for autosummary sys.path.append(os.path.dirname(__file__) + "/..") sys.path.append(os.path.dirname(__file__) + "/../recon_surf") sys.path.append(os.path.dirname(__file__) + "/sphinx_ext") @@ -259,7 +258,7 @@ def linkcode_resolve(domain, info): # myst_ref_domains = ["myst", "std", "py"] -_re_script_dirs = "fastsurfercnn|cerebnet|recon_surf|hypvinn" +_re_script_dirs = "fastsurfercnn|cerebnet|recon_surf|hypvinn|corpuscallosum" _up = "^/\\.\\./" _end = "(\\.md)?(#.*)?$" diff --git a/doc/overview/FLAGS.md b/doc/overview/FLAGS.md index 3f06d74dd..735136fbc 100644 --- a/doc/overview/FLAGS.md +++ b/doc/overview/FLAGS.md @@ -6,7 +6,7 @@ The `*fastsurfer-flags*` will usually at least include the subject directory (`- ```bash ... --sd /output --sid test_subject --t1 /data/test_subject_t1.nii.gz --3T ``` -Additionally, you can use `--seg_only` or `--surf_only` to only run a part of the pipeline or `--no_biasfield`, `--no_cereb` and `--no_asegdkt` to switch off individual segmentation modules. +Additionally, you can use `--seg_only` or `--surf_only` to only run a part of the pipeline or `--no_biasfield`, `--no_cereb`, `--no_hypothal`, `--no_cc`, and `--no_asegdkt` to switch off individual segmentation modules. Here, we have also added the `--3T` flag, which tells FastSurfer to register against the 3T atlas which is only relevant for the ICV estimation (eTIV). In the following, we give an overview of the most important options. You can view a [full list of options](FLAGS.md#full-list-of-flags) with @@ -30,6 +30,8 @@ In the following, we give an overview of the most important options. You can vie * `--device`: Select device for neural network segmentation (_auto_, _cpu_, _cuda_, _cuda:_, _mps_), where cuda means Nvidia GPU, you can select which one e.g. "cuda:1". Default: "auto", check GPU and then CPU. "mps" is for native MAC installs to use the Apple silicon (M-chip) GPU. * `--asegdkt_segfile`: Name of the segmentation file, which includes the aparc+DKTatlas-aseg segmentations. Requires an ABSOLUTE Path! Default location: \$SUBJECTS_DIR/\$sid/mri/aparc.DKTatlas+aseg.deep.mgz * `--no_cereb`: Switch off the cerebellum sub-segmentation. +* `--no_hypothal`: Skip the hypothalamus segmentation. +* `--no_cc`: Skip the segmentation and analysis of the corpus callosum. * `--cereb_segfile`: Name of the cerebellum segmentation file. If not provided, this intermediate DL-based segmentation will not be stored, but only the merged segmentation will be stored (see --main_segfile ). Requires an ABSOLUTE Path! Default location: \$SUBJECTS_DIR/\$sid/mri/cerebellum.CerebNet.nii.gz * `--no_biasfield`: Deactivate the biasfield correction and calculation of partial volume-corrected statistics in the segmentation modules. * `--native_image` or `--keepgeom`: **Only supported for `--seg_only`**, segment in native image space (keep orientation, image size and voxel size of the input image), this also includes experimental support for anisotropic images (no extreme anisotropy). diff --git a/doc/overview/OUTPUT_FILES.md b/doc/overview/OUTPUT_FILES.md index 87416ed30..8d7ab7b2a 100644 --- a/doc/overview/OUTPUT_FILES.md +++ b/doc/overview/OUTPUT_FILES.md @@ -15,6 +15,31 @@ The segmentation module outputs the files shown in the table below. The two prim | scripts | deep-seg.log | asegdkt | logfile | | stats | aseg+DKT.stats | asegdkt | table of cortical and subcortical segmentation statistics | + +## Corpus Callosum module + +The Corpus Callosum module outputs the files in the table shown below. It creates detailed segmentations and shape analysis of the corpus callosum. + +| directory | filename | module | description | +|:----------------|--------------------------------|--------|--------------------------------------------------------------------------------------------------------------| +| mri | callosum_seg_upright.mgz | cc | corpus callosum segmentation in upright space | +| mri | callosum_seg_aseg_space.mgz | cc | corpus callosum segmentation in conformed image orientation | +| mri | callosum_seg_soft.mgz | cc | corpus callosum soft labels | +| mri | fornix_seg_soft.mgz | cc | fornix soft labels | +| mri | background_seg_soft.mgz | cc | background soft labels | +| mri/transforms | cc_up.lta | cc | transform from original to upright space | +| mri/transforms | orient_volume.lta | cc | transform to standardized space | +| stats | callosum.CC.midslice.json | cc | measurements from the middle sagittal slice (landmarks, area, thickness, etc.) | +| stats | callosum.CC.all_slices.json | cc | comprehensive per-slice analysis (only when using `--slice_selection all`) | +| qc_snapshots | callosum.png | cc | debug visualization of contours and thickness | +| qc_snapshots | callosum_thickness.png | cc | 3D thickness visualization (with `--slice_selection all`) | +| qc_snapshots | corpus_callosum.html | cc | interactive 3D mesh visualization (with `--slice_selection all`) | +| surf | callosum.surf | cc | FreeSurfer surface format (with `--slice_selection all`) | +| surf | callosum.thickness.w | cc | FreeSurfer overlay file containing thickness values (with `--slice_selection all`) | +| surf | callosum_mesh.vtk | cc | VTK format mesh file for 3D visualization (with `--slice_selection all`) | + + + ## Cerebnet module The cerebellum module outputs the files in the table shown below. Unless switched off by the `--no_cereb` argument, this module is automatically run whenever the segmentation module is run. It adds two files, an image with the sub-segmentation of the cerebellum and a text file with summary statistics. @@ -73,4 +98,4 @@ The primary output files are pial, white, and inflated surface files, the thickn | stats | lh.aparc.DKTatlas.mapped.stats, rh.aparc.DKTatlas.mapped.stats | surface | table of cortical parcellation statistics, mapped from ASEGDKT segmentation to the surface | | stats | lh.curv.stats, rh.curv.stats | surface | table of curvature statistics | | stats | wmparc.DKTatlas.mapped.stats | surface | table of white matter segmentation statistics | -| scripts | recon-all.log | surface | logfile | \ No newline at end of file +| scripts | recon-all.log | surface | logfile | diff --git a/doc/overview/index.rst b/doc/overview/index.rst index e41f65932..2fca45ff3 100644 --- a/doc/overview/index.rst +++ b/doc/overview/index.rst @@ -10,6 +10,7 @@ User Guide EXAMPLES.md FLAGS.md OUTPUT_FILES.md + modules/index docker SINGULARITY.md MACOS.md diff --git a/doc/overview/modules/CC.md b/doc/overview/modules/CC.md new file mode 100644 index 000000000..5795136cd --- /dev/null +++ b/doc/overview/modules/CC.md @@ -0,0 +1,127 @@ +# Corpus Callosum Pipeline + +A deep learning-based pipeline for automated segmentation, analysis, and shape analysis of the corpus callosum in brain MRI scans. +Also segments the fornix, localizes the anterior and posterior commissure (AC and PC) and standardizes the orientation of the brain. + +## Overview + +This pipeline combines localization and segmentation deep learning models to: +1. Detect AC (Anterior Commissure) and PC (Posterior Commissure) points +2. Extract and align midplane slices +3. Segment the corpus callosum +4. Perform advanced morphometry for corpus callosum, including subdivision, thickness analysis, and various shape metrics +5. Generate visualizations and measurements + +## Analysis Modes + +The pipeline supports different analysis modes that determine the type of template data generated. + +### 3D Analysis + +When running the main pipeline with `--slice_selection all` and `--save_template`, a complete 3D template is generated: + +```bash +# Generate 3D template data +python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \ + --slice_selection all \ + --save_template /data/templates/sub001 +``` + +This creates: +- `contours.txt`: Multi-slice contour data for 3D reconstruction +- `thickness_values.txt`: Thickness measurements across all slices +- `measurement_points.txt`: 3D vertex indices for thickness measurements + +**Benefits:** +- Enables volumetric thickness analysis +- Supports advanced 3D visualizations with proper surface topology +- Creates FreeSurfer-compatible overlay files for integration with other tools + +For visualization instructions and outputs, see the [cc_visualization.py documentation](../../scripts/cc_visualization.rst). + +### 2D Analysis + +When using `--slice_selection middle` or a specific slice number with `--save_template`: + +```bash +# Generate 2D template data (middle slice) +python3 fastsurfer_cc.py --subject_dir /data/subjects/sub001 \ + --slice_selection middle \ + --save_template /data/templates/sub001 +``` + +**Benefits:** +- Faster processing for single-slice analysis +- 2D visualization is most suitable for displaying downstream statistics +- Compatibility with classical corpus callosum studies + +For 2D visualization instructions and outputs, see the [cc_visualization.py documentation](../../scripts/cc_visualization.rst). + +### Choosing Analysis Mode + +**Use 3D Analysis (`--slice_selection all`) when:** +- You need complete volumetric analysis +- Surface-based visualization is required +- Integration with FreeSurfer workflows is needed +- Comprehensive thickness mapping across the entire corpus callosum is desired + +**Use 2D Analysis (`--slice_selection middle` or specific slice) when:** +- Traditional single-slice morphometry is sufficient +- Faster processing is preferred +- Focus is on mid-sagittal cross-sectional measurements +- Compatibility with classical corpus callosum studies is needed + +**Note:** The default behavior is `--slice_selection all` for comprehensive 3D analysis. Use `--slice_selection middle` to process only the middle slice for faster, traditional 2D analysis. + +## JSON Output Structure + +The pipeline generates two main JSON files with detailed measurements and analysis results: + +### `stats/callosum.CC.midslice.json` (Middle Slice Analysis) + +This file contains measurements from the middle sagittal slice and includes: + +**Shape Measurements (single values):** +- `total_area`: Total corpus callosum area (mm²) +- `total_perimeter`: Total perimeter length (mm) +- `circularity`: Shape circularity measure (4π × area / perimeter²) +- `cc_index`: Corpus callosum shape index (length/width ratio) +- `midline_length`: Length along the corpus callosum midline (mm) +- `curvature`: Average curve of the midline (degrees), measured by angle between it's sub-segements + +**Subdivisions** +- `areas`: Areas of CC using an improved Hofer-Frahm sub-division method (mm²). This gives more consistent sub-segemnts while preserving the original ratios. + +**Thickness Analysis:** +- `thickness`: Average corpus callosum thickness (mm) +- `thickness_profile`: Thickness profile (mm) of the corpus callosum slice (100 thickness values by default, listed from anterior to posterior CC ends) + + +**Volume Measurements (when multiple slices processed):** +- `cc_5mm_volume`: Total CC volume within 5mm slab using voxel counting (mm³) +- `cc_5mm_volume_pv_corrected`: Volume with partial volume correction using CC contours (mm³) + +**Anatomical Landmarks:** +- `ac_center`: Anterior commissure coordinates in original image space +- `pc_center`: Posterior commissure coordinates in original image space +- `ac_center_oriented_volume`: AC coordinates in standardized space (orient_volume.lta) +- `pc_center_oriented_volume`: PC coordinates in standardized space (orient_volume.lta) +- `ac_center_upright`: AC coordinates in upright space (cc_up.lta) +- `pc_center_upright`: PC coordinates in upright space (cc_up.lta) + +### `stats/callosum.CC.all_slices.json` (Multi-Slice Analysis) + +This file contains comprehensive per-slice analysis when using `--slice_selection all`: + +**Global Parameters:** +- `slices_in_segmentation`: Total number of slices in the segmentation volume +- `voxel_size`: Voxel dimensions [x, y, z] in mm +- `subdivision_method`: Method used for anatomical subdivision +- `num_thickness_points`: Number of points used for thickness estimation +- `subdivision_ratios`: Subdivision fractions used for regional analysis +- `contour_smoothing`: Gaussian sigma used for contour smoothing +- `slice_selection`: Slice selection mode used + +**Per-Slice Data (`slices` array):** + +Each slice entry contains the shape measurements, thickness analysis and sub-divisions as described above. diff --git a/doc/overview/modules/index.rst b/doc/overview/modules/index.rst new file mode 100644 index 000000000..17b1cc454 --- /dev/null +++ b/doc/overview/modules/index.rst @@ -0,0 +1,9 @@ +Modules +======= + +FastSurfer includes several specialized deep learning modules that can be run independently or as part of the main pipeline. These modules provide detailed sub-segmentations and analyses for specific brain regions. + +.. toctree:: + :maxdepth: 2 + + CC diff --git a/doc/scripts/advanced.rst b/doc/scripts/advanced.rst index 82551a7ca..d18d755dd 100644 --- a/doc/scripts/advanced.rst +++ b/doc/scripts/advanced.rst @@ -7,6 +7,8 @@ Advanced scripts fastsurfercnn cerebnet hypvinn + fastsurfer_cc + cc_visualization recon_surf segstats long_compat_segmentHA diff --git a/doc/scripts/cc_visualization.rst b/doc/scripts/cc_visualization.rst new file mode 100644 index 000000000..b280a953c --- /dev/null +++ b/doc/scripts/cc_visualization.rst @@ -0,0 +1,55 @@ +CorpusCallosum: cc_visualization.py +=================================== + +.. argparse:: + :module: CorpusCallosum.cc_visualization + :func: make_parser + :prog: cc_visualization.py + +Usage Examples +-------------- + +3D Visualization +~~~~~~~~~~~~~~~~ + +To visualize a 3D template generated by ``fastsurfer_cc.py`` (using ``--slice_selection all --save_template ...``), +point the script to the exported template directory: + +.. code-block:: bash + + python3 cc_visualization.py \ + --template_dir /data/templates/sub001/cc_template \ + --output_dir /data/visualizations/sub001 + +2D Visualization +~~~~~~~~~~~~~~~~ + +To visualize a 2D template (using ``--slice_selection middle --save_template ...``): + +.. code-block:: bash + + python3 cc_visualization.py \ + --template_dir /data/templates/sub001/cc_template \ + --output_dir /data/visualizations/sub001 \ + --twoD + +.. note:: + + You can still pass ``--contours``, ``--thickness`` and + ``--measurement_points`` directly when working with standalone files, but + ``--template_dir`` is the recommended way to load the multi-slice templates + produced by ``fastsurfer_cc.py``. + +Outputs +------- + +3D Mode Outputs (default): + - ``cc_mesh.vtk``: VTK format mesh file for 3D visualization + - ``cc_mesh.fssurf``: FreeSurfer surface format + - ``cc_mesh_overlay.curv``: FreeSurfer overlay file with thickness values + - ``cc_mesh.html``: Interactive 3D mesh visualization + - ``cc_mesh_snap.png``: Snapshot image of the 3D mesh + - ``midslice_2d.png``: 2D visualization of the middle slice + +2D Mode Outputs (when ``--twoD`` is specified): + - ``cc_thickness_2d.png``: 2D contour visualization with thickness colormap diff --git a/doc/scripts/fastsurfer_cc.rst b/doc/scripts/fastsurfer_cc.rst new file mode 100644 index 000000000..d2f5fcbcd --- /dev/null +++ b/doc/scripts/fastsurfer_cc.rst @@ -0,0 +1,21 @@ +CorpusCallosum: fastsurfer_cc.py +================================ + +.. note:: + We recommend to run FastSurfer-CC with the standard `run_fastsurfer.sh` interfaces ! + + +.. + [Note] To tell sphinx where in the documentation CorpusCallosum/README.md can be linked to, it needs to be included somewhere + +.. include:: ../../CorpusCallosum/README.md + :parser: fix_links.parser + :start-line: 1 + +.. argparse:: + :module: CorpusCallosum.fastsurfer_cc + :func: make_parser + :prog: fastsurfer_cc.py + +.. include:: ../overview/modules/CC.md + :parser: fix_links.parser diff --git a/env/fastsurfer.yml b/env/fastsurfer.yml index e3b991bf4..a3c5231de 100644 --- a/env/fastsurfer.yml +++ b/env/fastsurfer.yml @@ -5,8 +5,9 @@ channels: dependencies: - h5py==3.12.1 -- lapy==1.2.0 +- lapy==1.4.0 - matplotlib==3.10.1 +- monai==1.4.0 - nibabel==5.3.2 - numpy==1.26.4 - pandas==2.2.3 @@ -29,3 +30,6 @@ dependencies: - torch==2.6.0+cu126 - torchio==0.20.4 - torchvision==0.21.0+cu126 + - meshpy>=2025.1.1 + - pyrr>=0.10.3 + - whippersnappy>=1.3.1 diff --git a/pyproject.toml b/pyproject.toml index 438bf23c3..655b63816 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ classifiers = [ ] dependencies = [ 'h5py>=3.7', - 'lapy>=1.1.0', + "lapy>=1.4.0", 'matplotlib>=3.7.1', 'nibabel>=5.1.0', 'numpy>=1.25,<2', @@ -50,12 +50,17 @@ dependencies = [ 'torchvision>=0.15.2', 'tqdm>=4.65', 'yacs>=0.1.8', + 'monai>=1.4.0', + 'meshpy>=2025.1.1', + 'pyrr>=0.10.3', + 'pip>=25.0', ] [project.optional-dependencies] doc = [ 'furo!=2023.8.17', 'matplotlib', + 'whippersnappy>=1.3.1', 'memory-profiler', 'myst-parser', 'numpydoc', diff --git a/recon_surf/align_points.py b/recon_surf/align_points.py index e549466d4..c69446ee4 100755 --- a/recon_surf/align_points.py +++ b/recon_surf/align_points.py @@ -127,8 +127,7 @@ def find_rotation(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: return R - -def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: +def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray, verbose: bool = False) -> npt.NDArray[float]: """ Find rigid transformation matrix between two point sets. @@ -138,10 +137,12 @@ def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: Source points. p_dst : npt.NDArray Destination points. + verbose : bool, optional + Whether to print debug information, by default False. Returns ------- - T + np.ndarray Homogeneous transformation matrix. """ if p_mov.shape != p_dst.shape: @@ -159,16 +160,17 @@ def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: t = centroid_dst.T - np.dot(R, centroid_mov.T) # homogeneous transformation m = p_mov.shape[1] - T = np.identity(m + 1) - T[:m, :m] = R - T[:m, m] = t + rigid_transform = np.identity(m + 1, dtype=float) + rigid_transform[:m, :m] = R + rigid_transform[:m, m] = t # compute disteances - dd = p_mov - p_dst - print(f"Initial avg SSD: {np.sum(dd * dd) / p_mov.shape[0]}") - dd = (np.transpose(R @ np.transpose(p_mov)) + t) - p_dst - print(f"Final avg SSD: {np.sum(dd * dd) / p_mov.shape[0]}") + if verbose: + dd = p_mov - p_dst + print(f"Initial avg SSD: {np.sum(dd * dd) / p_mov.shape[0]}") + dd = (np.transpose(R @ np.transpose(p_mov)) + t) - p_dst + print(f"Final avg SSD: {np.sum(dd * dd) / p_mov.shape[0]}") # return T, R, t - return T + return rigid_transform def find_affine(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: """ diff --git a/recon_surf/align_seg.py b/recon_surf/align_seg.py index dd0d6c6c2..9332a6cdc 100755 --- a/recon_surf/align_seg.py +++ b/recon_surf/align_seg.py @@ -21,11 +21,12 @@ import align_points as align import image_io as iio -import lta as lta import numpy as np import SimpleITK as sitk from numpy import typing as npt +from FastSurferCNN.utils.lta import write_lta + HELPTEXT = """ Script to align two images based on the centroids of their segmentations @@ -397,7 +398,7 @@ def align_flipped(seg: sitk.Image, mid_slice: float | None = None) -> npt.NDArra # write transform lta print(f"writing: {options.outlta}") - lta.write_lta( + write_lta( options.outlta, T, options.srcseg, srcheader, options.trgseg, trgheader ) diff --git a/recon_surf/lta.py b/recon_surf/lta.py deleted file mode 100755 index c6011102c..000000000 --- a/recon_surf/lta.py +++ /dev/null @@ -1,103 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2021 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from pathlib import Path - -import numpy.typing as npt - -# Collection of functions related to FreeSurfer's LTA (linear transform array) files: - -def write_lta( - filename: Path | str, - affine: npt.ArrayLike, - src_fname: Path | str, - src_header: dict, - dst_fname: Path | str, - dst_header: dict -) -> None: - """ - Write linear transform array info to an .lta file. - - Parameters - ---------- - filename : Path, str - File to write on. - affine : npt.ArrayLike - Linear transform array to be saved. - src_fname : Path, str - Source filename. - src_header : Dict - Source header. - dst_fname : Path, str - Destination filename. - dst_header : Dict - Destination header. - - Raises - ------ - ValueError - Header format missing field (Source or Destination). - """ - import getpass - from datetime import datetime - - fields = ("dims", "delta", "Mdc", "Pxyz_c") - for field in fields: - if field not in src_header: - raise ValueError(f"write_lta Error: src_header format missing field: {field}") - if field not in dst_header: - raise ValueError(f"write_lta Error: dst_header format missing field: {field}") - - src_dims = str(src_header["dims"][0:3]) - src_vsize = str(src_header["delta"][0:3]) - src_v2r = src_header["Mdc"] - src_c = src_header["Pxyz_c"] - - dst_dims = str(dst_header["dims"][0:3]) - dst_vsize = str(dst_header["delta"][0:3]) - dst_v2r = dst_header["Mdc"] - dst_c = dst_header["Pxyz_c"] - - f = open(filename, "w") - f.write( - (f"# transform file {filename}\n" - f"# created by {getpass.getuser()} on {datetime.now().ctime()}\n\n" - "type = 1 # LINEAR_RAS_TO_RAS\n" - "nxforms = 1\n" - "mean = 0.0 0.0 0.0\n" - "sigma = 1.0\n" - "1 4 4\n" - f"{affine}\n" - "src volume info\n" - "valid = 1 # volume info valid\n" - f"filename = {src_fname}\n" - f"volume = {src_dims}\n" - f"voxelsize = {src_vsize}\n" - f"xras = {src_v2r[0, :]}\n" - f"yras = {src_v2r[1, :]}\n" - f"zras = {src_v2r[2, :]}\n" - f"cras = {src_c}\n" - "dst volume info\n" - "valid = 1 # volume info valid\n" - f"filename = {dst_fname}\n" - f"volume = {dst_dims}\n" - f"voxelsize = {dst_vsize}\n" - f"xras = {dst_v2r[0, :]}\n" - f"yras = {dst_v2r[1, :]}\n" - f"zras = {dst_v2r[2, :]}\n" - f"cras = {dst_c}\n").replace("[", "").replace("]", "") - ) - f.close() diff --git a/recon_surf/paint_cc_into_pred.py b/recon_surf/paint_cc_into_pred.py deleted file mode 100644 index ec649a869..000000000 --- a/recon_surf/paint_cc_into_pred.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# IMPORTS - -import argparse -import sys - -import nibabel as nib -import numpy as np -from numpy import typing as npt - -HELPTEXT = """ -Script to add corpus callosum segmentation (CC, FreeSurfer IDs 251-255) to -deep-learning prediction (e.g. aparc.DKTatlas+aseg.deep.mgz). - - -USAGE: -paint_cc_into_pred -in_cc -in_pred -out - - -Dependencies: - Python 3.8+ - - Nibabel to read and write FreeSurfer data - http://nipy.org/nibabel/ - -Original Author: Leonie Henschel -Date: Jul-10-2020 - -""" - - -def argument_parse(): - """ - Create a command line interface and return command line options. - - Returns - ------- - options : argparse.Namespace - Namespace object holding options. - """ - parser = argparse.ArgumentParser(usage=HELPTEXT) - parser.add_argument( - "--input_cc", - "-in_cc", - dest="input_cc", - help="path to input segmentation with Corpus Callosum (IDs 251-255 in FreeSurfer space)", - ) - parser.add_argument( - "--input_pred", - "-in_pred", - dest="input_pred", - help="path to input segmentation Corpus Callosum should be added to.", - ) - parser.add_argument( - "--output", - "-out", - dest="output", - help="path to output (input segmentation + added CC)", - ) - - args = parser.parse_args() - - if args.input_cc is None or args.input_pred is None or args.output is None: - sys.exit("ERROR: Please specify input and output segmentations") - - return args - - -def paint_in_cc(pred: npt.ArrayLike, aseg_cc: npt.ArrayLike) -> npt.ArrayLike: - """ - Paint corpus callosum segmentation into aseg+dkt segmentation map. - - Note, that this function modifies the original array and does not create a copy. - - Parameters - ---------- - asegdkt : npt.ArrayLike - Deep-learning segmentation map. - aseg_cc : npt.ArrayLike - Aseg segmentation with CC. - - Returns - ------- - asegdkt - Segmentation map with added CC. - """ - cc_mask = (aseg_cc >= 251) & (aseg_cc <= 255) - pred[cc_mask] = aseg_cc[cc_mask] - return pred - - -if __name__ == "__main__": - # Command Line options are error checking done here - options = argument_parse() - - print(f"Reading inputs: {options.input_cc} {options.input_pred}...") - aseg_image = np.asanyarray(nib.load(options.input_cc).dataobj) - prediction = nib.load(options.input_pred) - pred_with_cc = paint_in_cc(np.asanyarray(prediction.dataobj), aseg_image) - - print(f"Writing segmentation with corpus callosum to: {options.output}") - pred_with_cc_fin = nib.MGHImage(pred_with_cc, prediction.affine, prediction.header) - pred_with_cc_fin.to_filename(options.output) - - sys.exit(0) - - -# TODO: Rename the file (paint_cc_into_asegdkt or similar) and functions. diff --git a/recon_surf/recon-surf.sh b/recon_surf/recon-surf.sh index 648064710..5727c8505 100755 --- a/recon_surf/recon-surf.sh +++ b/recon_surf/recon-surf.sh @@ -619,21 +619,52 @@ fi # ============================= CC SEGMENTATION ============================================ -{ - echo " " - echo "============ Creating and adding CC Segmentation ============" - echo " " -} | tee -a "$LF" -# create aseg.auto including corpus callosum segmentation and 46 sec, requires norm.mgz -# Note: if original input segmentation already contains CC, this will exit with ERROR -# in the future maybe check and skip this step (and next) -cmd="mri_cc -aseg $aseg_nocc -o aseg.auto.mgz -lta $mdir/transforms/cc_up.lta $subject" -RunIt "$cmd" "$LF" -# add CC into aparc.DKTatlas+aseg.deep (not sure if this is really needed) -cmd="$python ${binpath}paint_cc_into_pred.py -in_cc $mdir/aseg.auto.mgz -in_pred $asegdkt_segfile -out $mdir/aparc.DKTatlas+aseg.deep.withCC.mgz" -RunIt "$cmd" "$LF" - - +# here, we are only generating the "necessary" files for the pipeline to recon-surf pipeline to +# complete, people should use the seg pipeline to get extended results. +callosum_seg="callosum_seg_aseg_space.mgz" +callosum_seg_manedit="$(add_file_suffix "$callosum_seg" "manedit")" +aseg_auto="aseg.auto.mgz" +CorpusCallosumDir="$FASTSURFER_HOME/CorpusCallosum" +updated_cc_seg=0 +if [[ ! -e "$mdir/$aseg_auto" ]] || [[ ! -e "$mdir/$callosum_seg" ]] || [[ "$edits" == 1 ]] +then + { + echo " " + echo "============ Creating and adding CC Segmentation ============" + echo " " + } | tee -a "$LF" +fi +# here, in edits mode we also check, if the corpus callosum should be updated based on an updated aseg.nocc +if [[ ! -e "$mdir/$callosum_seg" ]] || \ + { [[ "$edits" == 1 ]] && [[ "$(date -r "$mdir/$aseg_nocc" "+%s")" -gt "$(date -r "$mdir/$callosum_seg" "+%s")" ]] ; } +then + { + echo "Segmenting the corpus callosum, so mri/$aseg_nocc exists. If you are interested in detailed" + echo " and extended analysis and statistics of the Corpus Callosum, use the corpus callosum pipeline" + echo " of the segmentation pipeline (in run_fastsurfer.sh, i.e. run without --no_cc)." + } + updated_cc_seg=1 + # create aseg.auto including corpus callosum segmentation and 46 sec, requires norm.mgz + # Note: if original input segmentation already contains CC, this will exit with ERROR + # in the future maybe check and skip this step (and next) + cmda=($python "$CorpusCallosumDir/fastsurfer_cc.py" --sd "$SUBJECTS_DIR" --sid "$subject" + "--aseg_name" "$mdir/$aseg_nocc" "--segmentation_in_orig" "$mdir/$callosum_seg" + --threads "$threads" + # qc_snapshots are only defined by the seg_only pipeline + # limit the processing things to do here + --slice_selection "middle" --cc_measures "none" --cc_mid_measures "none" --surf "none" + --thickness_overlay "none") + run_it "$LF" "${cmda[@]}" +fi +# do not move below statement up, fastsurfer_cc.py uses the $callosum_seg variable +if [[ "$edits" == 1 ]] && [[ -e "$mdir/$callosum_seg_manedit" ]] ; then callosum_seg="$callosum_seg_manedit" ; fi +cmd_paint_cc_into_pred=($python "$CorpusCallosumDir/paint_cc_into_pred.py" -in_cc "$mdir/$callosum_seg" -in_pred) +if [[ ! -e "$mdir/$aseg_auto" ]] || [[ "$updated_cc_seg" == 1 ]] +then + # add CC into aseg.auto.mgz as mri_cc did before. Not sure where this is used. + cmda=("${cmd_paint_cc_into_pred[@]}" "$mdir/$aseg_nocc" "-out" "$mdir/$aseg_auto") + run_it "$LF" "${cmda[@]}" +fi # ============================= FILLED ===================================================== { diff --git a/requirements.mac.txt b/requirements.mac.txt index 95af69a75..5f8775c61 100644 --- a/requirements.mac.txt +++ b/requirements.mac.txt @@ -16,4 +16,8 @@ torchio>=0.18.83 torchvision>=0.15.2 tqdm>=4.65 yacs>=0.1.8 - +monai>=1.4.0 +meshpy>=2025.1.1 +pyrr>=0.10.3 +whippersnappy>=1.3.1 +pip>=25.0 \ No newline at end of file diff --git a/run_fastsurfer.sh b/run_fastsurfer.sh index 83505e912..39d5e7ac2 100755 --- a/run_fastsurfer.sh +++ b/run_fastsurfer.sh @@ -32,6 +32,7 @@ fastsurfercnndir="$FASTSURFER_HOME/FastSurferCNN" cerebnetdir="$FASTSURFER_HOME/CerebNet" hypvinndir="$FASTSURFER_HOME/HypVINN" reconsurfdir="$FASTSURFER_HOME/recon_surf" +CorpusCallosumDir="$FASTSURFER_HOME/CorpusCallosum" # Regular flags defaults subject="" @@ -49,6 +50,7 @@ hypo_segfile="" hypo_statsfile="" hypvinn_flags=() hypvinn_regmode="coreg" +cc_flags=() conformed_name="" conformed_name_t2="" norm_name="" @@ -70,6 +72,7 @@ native_image="false" run_asegdkt_module="1" run_cereb_module="1" run_hypvinn_module="1" +run_cc_module="1" threads_seg="1" threads_surf="1" # python3.10 -s excludes user-directory package inclusion @@ -213,6 +216,11 @@ SEGMENTATION PIPELINE: --no_biasfield Deactivate the calculation of partial volume-corrected statistics. + CORPUS CALLOSUM MODULE: + --no_cc Skip the segmentation and analysis of the corpus callosum. + --qc_snap Create QC snapshots in \$SUBJECTS_DIR/\$sid/qc_snapshots + to simplify the QC process. + HYPOTHALAMUS MODULE (HypVINN): --no_hypothal Skip the hypothalamus segmentation. --no_biasfield This option implies --no_hypothal, as the hypothalamus @@ -458,6 +466,12 @@ case $key in --mask_name) mask_name="$1" ; warn_seg_only+=("$key" "$1") ; warn_base+=("$key" "$1") ; shift ;; --merged_segfile) merged_segfile="$1" ; shift ;; + # corupus callosum module options + #============================================================= + --no_cc) run_cc_module="0" ;; + # TODO: remove this dev flag + --upright) cc_flags+=("--upright_volume" "mri/upright.mgz") ;; + # cereb module options #============================================================= --no_cereb) run_cereb_module="0" ;; @@ -480,7 +494,11 @@ case $key in ;; # several options that set a variable - --qc_snap) hypvinn_flags+=(--qc_snap) ;; + --qc_snap) + hypvinn_flags+=(--qc_snap) ; + cc_flags+=(--qc_image "qc_snapshots/callosum.png" --thickness_image "qc_snapshots/callosum.thickness.png" + --cc_html "qc_snapshots/corpus_callosum.html") + ;; ############################################################## # surf-pipeline options @@ -588,6 +606,8 @@ fi if [[ -z "$merged_segfile" ]] ; then merged_segfile="$subject_dir/mri/fastsurfer.merged.mgz" ; fi if [[ -z "$asegdkt_segfile" ]] ; then asegdkt_segfile="$subject_dir/mri/aparc.DKTatlas+aseg.deep.mgz" ; fi if [[ -z "$aseg_segfile" ]] ; then aseg_segfile="$subject_dir/mri/aseg.auto_noCCseg.mgz"; fi +if [[ -z "$aseg_auto_segfile" ]] ; then aseg_auto_segfile="$subject_dir/mri/aseg.auto.mgz"; fi +if [[ -z "$callosum_seg" ]] ; then callosum_seg="$subject_dir/mri/callosum.CC.orig.mgz"; fi if [[ -z "$asegdkt_statsfile" ]] ; then asegdkt_statsfile="$subject_dir/stats/aseg+DKT.stats" ; fi if [[ -z "$asegdkt_vinn_statsfile" ]] ; then asegdkt_vinn_statsfile="$subject_dir/stats/aseg+DKT.VINN.stats" ; fi if [[ -z "$aseg_vinn_statsfile" ]] ; then aseg_vinn_statsfile="$subject_dir/stats/aseg.VINN.stats" ; fi @@ -708,6 +728,18 @@ then fi fi +if [[ "$run_seg_pipeline" == "1" ]] && { [[ "$run_asegdkt_module" == "0" ]] && [[ "$run_cc_module" == "1" ]]; } +then + if [[ ! -f "$asegdkt_segfile" ]] + then + echo "ERROR: To run the corpus callosum module but no asegdkt, the aseg segmentation must already exist." + echo " You passed --no_asegdkt but the asegdkt segmentation ($asegdkt_segfile) could not be found." + echo " If the segmentation is not saved in the default location ($asegdkt_segfile_default)," + echo " specify the absolute path and name via --asegdkt_segfile" + exit 1 + fi +fi + if [[ "$run_surf_pipeline" == "1" ]] && [[ "$native_image" != "false" ]] then echo "ERROR: The surface pipeline is not compatible with the options --native_image or " @@ -1078,6 +1110,88 @@ then fi fi + if [[ "$run_cc_module" ]] + then + # ============================= CC SEGMENTATION ============================================ + + # generate file names of for the analysis + asegdkt_withcc_segfile="$(add_file_suffix "$asegdkt_segfile" "withCC")" + asegdkt_withcc_vinn_statsfile="$(add_file_suffix "$asegdkt_vinn_statsfile" "withCC")" + aseg_auto_statsfile="$(dirname "$aseg_vinn_statsfile")/aseg.auto.mgz" + # note: callosum manedit currently only affects inpainting and not internal FastSurferCC processing (surfaces etc) + callosum_seg_manedit="$(add_file_suffix "$callosum_seg" "manedit")" + # generate callosum segmentation, mesh, shape and downstream measure files + cmd=($python "$CorpusCallosumDir/fastsurfer_cc.py" --sd "$sd" --sid "$subject" --threads "$threads_seg" + "--aseg_name" "$asegdkt_segfile" "--segmentation_in_orig" "$callosum_seg" "${cc_flags[@]}") + { + echo_quoted "${cmd[@]}" + "${cmd[@]}" + if [[ "${PIPESTATUS[0]}" != 0 ]] ; then echo "ERROR: FastSurferCC corpus callosum analysis failed!" ; exit 1 ; fi + if [[ "$edits" == 1 ]] && [[ -f "$callosum_seg_manedit" ]] ; then callosum_seg="$callosum_seg_manedit" ; fi + + # add CC into aparc.DKTatlas+aseg.deep.mgz and aseg.auto.mgz as mri_cc did before. + cmd=($python "$CorpusCallosumDir/paint_cc_into_pred.py" -in_cc "$callosum_seg" -in_pred "$asegdkt_segfile" + "-out" "$asegdkt_withcc_segfile" "-aseg" "$aseg_auto_segfile") + echo_quoted "${cmd[@]}" + "${cmd[@]}" + if [[ "${PIPESTATUS[0]}" != 0 ]] ; then echo "ERROR: asegdkt cc inpainting failed!" ; exit 1 ; fi + + if [[ "$run_biasfield" == 1 ]] + then + cmd=($python "${fastsurfercnndir}/segstats.py" --segfile "$asegdkt_withcc_segfile" --normfile "$norm_name" + --lut "$fastsurfercnndir/config/FreeSurferColorLUT.txt" --sd "${sd}" --sid "${subject}" + --ids 2 4 5 7 8 10 11 12 13 14 15 16 17 18 24 26 28 31 41 43 44 46 47 49 50 51 52 53 + 54 58 60 63 77 251 252 253 254 255 + 1002 1003 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 + 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1034 1035 + 2002 2003 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 + 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2034 2035 + --threads "$threads_seg" --empty --excludeid 0 + --segstatsfile "$asegdkt_withcc_vinn_statsfile" + measures + # the following measures are unaffected by CC and do not need to be recomputed + --import SubCortGray Mask + ) + if [[ "$run_talairach_registration" == "true" ]] + then + cmd+=("EstimatedTotalIntraCranialVol" "BrainSegVol-to-eTIV" "MaskVol-to-eTIV") + fi + cmd+=(--file "$asegdkt_vinn_statsfile" + # recompute the measures changes coming from CC inpainting (only SubCortGray does not change) + --compute BrainSeg BrainSegNotVent SupraTentorial SupraTentorialNotVent + rhCerebralWhiteMatter lhCerebralWhiteMatter CerebralWhiteMatter + ) + echo_quoted "${cmd[@]}" + "${cmd[@]}" + if [[ "${PIPESTATUS[0]}" != 0 ]] ; then + echo "ERROR: asegdkt statsfile ($asegdkt_withcc_segfile) generation failed!" ; exit 1 + # this will only terminate the subshell + fi + fi + } 2>&1 | tee -a "$seg_log" + code="${PIPESTATUS[0]}" + if [[ "$code" != 0 ]]; then exit 1; fi # forward subshell exit to main script + + if [[ "$run_biasfield" == 1 ]] + then + { + cmd=($python "${fastsurfercnndir}/segstats.py" --segfile "$aseg_auto_segfile" --normfile "$norm_name" + --lut "$fastsurfercnndir/config/FreeSurferColorLUT.txt" --sd "${sd}" --sid "${subject}" + --threads "$threads_seg" --empty --excludeid 0 + --ids 2 4 3 5 7 8 10 11 12 13 14 15 16 17 18 24 26 28 31 41 42 43 44 46 47 49 50 51 52 53 54 58 60 63 77 + 251 252 253 254 255 + --segstatsfile "$aseg_auto_statsfile" + measures --import "all" --file "$asegdkt_withcc_vinn_statsfile" + ) + echo_quoted "${cmd[@]}" + "${cmd[@]}" 2>&1 + if [[ "${PIPESTATUS[0]}" != 0 ]] ; then echo "ERROR: aseg statsfile ($aseg_auto_segfile) failed!" ; exit 1 ; fi + } | tee -a "$seg_log" + if [[ "${PIPESTATUS[0]}" != 0 ]] ; then exit 1; fi # forward subshell exit to main script + + fi + fi + if [[ "$run_cereb_module" == "1" ]] then if [[ "$run_biasfield" == "1" ]] diff --git a/tools/export_pip-r.sh b/tools/export_pip-r.sh old mode 100644 new mode 100755