(UNETR): Add predict label function and custom dataloader that can train with your own data #420
Conversation
Walkthrough

Adds BTCV dataset support with MONAI-based training/validation and prediction loaders, introduces new dataset/config constants, changes UNETR positional-embedding defaults and validation, updates training/inference flows (sliding-window inference, prediction saving), adjusts transforms/utilities and dependencies, and refreshes the BTCV README and CLI examples.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor User
    participant Main as main.py
    participant Loader as get_loader / getDatasetLoader
    participant Trainer as trainer.py
    participant Model as UNETR
    User->>Main: python main.py [--btcv ...]
    alt args.btcv == True
        Main->>Loader: getDatasetLoader(args)
    else
        Main->>Loader: get_loader(args)
    end
    Loader-->>Main: train_loader, val_loader
    Main->>Model: initialize UNETR (pos_embed='learnable'|'sincos'|'none')
    Main->>Trainer: train(train_loader, val_loader, Model, args)
    note over Trainer,Model: Validation cadence reduced (val_every lowered)
```

```mermaid
sequenceDiagram
    autonumber
    actor User
    participant Test as test.py
    participant Dset as getDatasetLoader/getPredictLoader
    participant Model as UNETR
    participant SWI as sliding_window_inference
    participant IO as SaveImaged
    User->>Test: python test.py --mode validation|predict
    alt mode == validation
        Test->>Dset: getDatasetLoader(args)
        Dset-->>Test: val_loader
        Test->>Model: load checkpoint
        loop batches
            Test->>SWI: inference(inputs, model, args)
            SWI-->>Test: logits
            Test->>Test: softmax -> argmax -> labels -> Dice
        end
        Test-->>User: Mean Dice results
    else mode == predict
        Test->>Dset: getPredictLoader(args)
        Dset-->>Test: pred_loader, preTransform
        Test->>Model: load checkpoint
        loop items
            Test->>SWI: inference(image, model, args)
            SWI-->>Test: logits
            Test->>IO: Invertd -> AsDiscreted -> SaveImaged
        end
        Test-->>User: Saved segmentations
    end
```
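For readers following the inference flow in the diagrams, here is a minimal, self-contained sketch of the sliding-window step followed by softmax/argmax. It is illustrative only: the roi size, overlap, and channel counts are assumptions, and MONAI's bundled UNETR stands in for the repository's own `networks/unetr.py` class.

```python
import torch
from monai.inferers import sliding_window_inference
from monai.networks.nets import UNETR

# Illustrative values; in this PR the roi size and overlap come from argparse.
model = UNETR(in_channels=1, out_channels=14, img_size=(96, 96, 96)).eval()
image = torch.rand(1, 1, 160, 160, 96)  # (B, C, H, W, D) volume after preprocessing

with torch.no_grad():
    logits = sliding_window_inference(
        inputs=image,
        roi_size=(96, 96, 96),
        sw_batch_size=4,
        predictor=model,
        overlap=0.5,
    )

# Collapse class logits to a label map, as the validation branch does.
labels = torch.argmax(torch.softmax(logits, dim=1), dim=1)  # (B, H, W, D)
```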
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
UNETR/BTCV/networks/unetr.py (1)

88-99: Guard ViT instantiation on `pos_embed_type`.

MONAI renamed the ViT kwarg from `pos_embed` (available in MONAI 1.0.1) to `pos_embed_type` (introduced in MONAI 1.3.0) (docs.monai.io). Use introspection on `inspect.signature(ViT).parameters` to pass `pos_embed_type` when available, otherwise fall back to `pos_embed`.

Apply:

```diff
-        self.vit = ViT(
-            in_channels=in_channels,
-            img_size=img_size,
-            patch_size=self.patch_size,
-            hidden_size=hidden_size,
-            mlp_dim=mlp_dim,
-            num_layers=self.num_layers,
-            num_heads=num_heads,
-            pos_embed_type=pos_embed,
-            classification=self.classification,
-            dropout_rate=dropout_rate,
-        )
+        import inspect
+        vit_args = dict(
+            in_channels=in_channels,
+            img_size=img_size,
+            patch_size=self.patch_size,
+            hidden_size=hidden_size,
+            mlp_dim=mlp_dim,
+            num_layers=self.num_layers,
+            num_heads=num_heads,
+            classification=self.classification,
+            dropout_rate=dropout_rate,
+        )
+        if "pos_embed_type" in inspect.signature(ViT).parameters:
+            vit_args["pos_embed_type"] = pos_embed
+        else:
+            vit_args["pos_embed"] = pos_embed
+        self.vit = ViT(**vit_args)
```

UNETR/BTCV/README.md (1)

19-33: Fix UNETR init snippet: `pos_embed` value outdated. The code now supports ['sincos','learnable','none']; "perceptron" will raise.

Apply:

```diff
-    pos_embed='perceptron',
+    pos_embed='learnable',
```
🧹 Nitpick comments (15)
UNETR/BTCV/config.py (1)

1-3: Avoid hard-coded data roots; make them configurable via env vars. This improves portability and lets users point to their own data without editing code.

Apply this diff (and add the import shown below) to allow overrides:

```diff
-NIFTI_DATA_ROOT = 'data/images'    # nifti image directory
-NIFTI_LABEL_ROOT = 'data/labels'   # nifti label directory
-PREDICT_DATA_ROOT = 'data/predict' # predict image directory
+NIFTI_DATA_ROOT = os.getenv("NIFTI_DATA_ROOT", "data/images")      # nifti image directory
+NIFTI_LABEL_ROOT = os.getenv("NIFTI_LABEL_ROOT", "data/labels")    # nifti label directory
+PREDICT_DATA_ROOT = os.getenv("PREDICT_DATA_ROOT", "data/predict") # predict image directory
```

Outside the selected lines, add:

```python
import os
```

UNETR/BTCV/requirements.txt (1)

5-5: Normalize package name casing. The import is `import tensorboardX` but the requirement lists `tensorboardx`. Pip is case-insensitive, but for consistency consider `tensorboardX==2.6.4`.

UNETR/BTCV/utils/data_utils.py (1)

16-17: Remove unused imports and constants. `Path`, `sitk`, `WORKROOT`, and `JPG_EXT` aren't used here.

Apply:

```diff
-from pathlib import Path
-import SimpleITK as sitk
+# (removed unused imports)

-WORKROOT = Path(__file__).parent.parent
-JPG_EXT = '.jpg'
+# (removed unused constants)
```

Also applies to: 22-23
UNETR/BTCV/networks/unetr.py (1)

61-66: Update examples to match the new default/allowed values. Replace `pos_embed='perceptron'` with `'learnable'` and drop `'conv'`.

Apply:

```diff
-    # for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm
-    >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance')
+    # for 4-channel input 3-channel output with patch size of (128,128,128), learnable position embedding and instance norm
+    >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='learnable', norm_name='instance')
```

UNETR/BTCV/README.md (6)
42-53: Make the CLI block copy-pasteable. Use line continuations for multi-line shell commands.

Apply:

````diff
-```bash
-python main.py
---feature_size=32
---batch_size=1
---logdir=unetr_test
---fold=0
---optim_lr=1e-4
---lrschedule=warmup_cosine
---infer_overlap=0.5
---save_checkpoint
---data_dir=/dataset/dataset0/
-```
+```bash
+python main.py \
+  --feature_size=32 \
+  --batch_size=1 \
+  --logdir=unetr_test \
+  --fold=0 \
+  --optim_lr=1e-4 \
+  --lrschedule=warmup_cosine \
+  --infer_overlap=0.5 \
+  --save_checkpoint \
+  --data_dir=/dataset/dataset0/
+```
````

11-13: Use correct code fences. These blocks are Python, not bash.

Apply:

````diff
-```bash
+```python
````

Also applies to: 19-33

130-147: Clarify predict-mode data location. Mention PREDICT_DATA_ROOT (config or env var) so users know where to place images.

Proposed addition after the predict command: "Place images under PREDICT_DATA_ROOT (default: data/predict) or set the env var PREDICT_DATA_ROOT."

7-7: Fix heading level increment (MD001). "Installing Dependencies" should be H2 to follow H1.

Apply:

```diff
-### Installing Dependencies
+## Installing Dependencies
```

171-171: Typo: "Left Kideny" → "Left Kidney".

```diff
-- Target: 13 abdominal organs including 1. Spleen 2. Right Kidney 3. Left Kideny 4.Gallbladder ...
+- Target: 13 abdominal organs including 1. Spleen 2. Right Kidney 3. Left Kidney 4. Gallbladder ...
```

70-70: Avoid bare URLs; make them clickable text. Improves readability and satisfies MD034.

Example:

```diff
-https://developer.download.nvidia.com/assets/Clara/monai/research/UNETR_model_best_acc.pth
+[Pretrained UNETR checkpoint (.pth)](https://developer.download.nvidia.com/assets/Clara/monai/research/UNETR_model_best_acc.pth)
```

Apply similarly to the TorchScript link and tutorial links.

Also applies to: 97-97, 157-157, 161-161
UNETR/BTCV/main.py (1)
46-46: Flag defaults to True on --save_checkpoint; it cannot be disabled. Using `action="store_true", default=True` prevents users from turning it off. Prefer `BooleanOptionalAction` or an explicit `--no-save-checkpoint`.

```diff
-parser.add_argument("--save_checkpoint", action="store_true", default=True, help="save checkpoint during training")
+parser.add_argument(
+    "--save_checkpoint",
+    action=argparse.BooleanOptionalAction,
+    default=True,
+    help="save checkpoint during training",
+)
```

UNETR/BTCV/test.py (1)

109-121: Validation loop only measures the first item if batch_size > 1. Indexing `[0]` ignores the rest of the batch. Iterate per item to support batch > 1.

```diff
-    for batch, label in loader:
-        val_inputs, val_labels = (batch.cuda(), label.cuda())
-        val_outputs = inference(val_inputs, model, args)
-        val_labels = val_labels.cpu().numpy()[:, 0, :, :, :]
-        dice_list_sub = []
-        for i in range(1, args.out_channels):
-            every_Dice = dice(val_outputs[0] == i, val_labels[0] == i)
-            dice_list_sub.append(every_Dice)
-        mean_dice = np.mean(dice_list_sub)
+    for batch, label in loader:
+        val_inputs, val_labels = (batch.cuda(), label.cuda())
+        val_outputs = inference(val_inputs, model, args)   # (B, H, W, D)
+        val_labels = val_labels.cpu().numpy()[:, 0, ...]   # (B, H, W, D)
+        batch_dice = []
+        for b in range(val_outputs.shape[0]):
+            dice_list_sub = []
+            for i in range(1, args.out_channels):
+                dice_list_sub.append(dice(val_outputs[b] == i, val_labels[b] == i))
+            batch_dice.append(np.mean(dice_list_sub))
+        mean_dice = float(np.mean(batch_dice))
         print("Mean Dice: {}".format(mean_dice))
         dice_list_case.append(mean_dice)
```

UNETR/BTCV/dataset/customDataset.py (3)
32-34: Deterministic pairing and basic filtering. Sort file lists and optionally filter valid extensions to keep image/label alignment stable across runs.

```diff
-    dataName = [d for d in os.listdir(NIFTI_LABEL_ROOT)]
+    dataName = sorted(d for d in os.listdir(NIFTI_LABEL_ROOT) if d.endswith((".nii", ".nii.gz")))
```

86-88: DataLoader perf knobs. Consider `pin_memory` and `persistent_workers` for faster host→GPU transfer.

```diff
-    trainLoader = DataLoader(trainDataset,batch_size=args.batch_size,shuffle=True,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=True))
-    valLoader = DataLoader(valDataset,batch_size=args.batch_size,shuffle=False,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=False))
+    trainLoader = DataLoader(
+        trainDataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.workers,
+        pin_memory=True,
+        persistent_workers=True,
+        collate_fn=_get_collate_fn(isTrain=True),
+    )
+    valLoader = DataLoader(
+        valDataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.workers,
+        pin_memory=True,
+        persistent_workers=True,
+        collate_fn=_get_collate_fn(isTrain=False),
+    )
```

92-97: Rename the ambiguous parameter `l`. Avoid single-letter "l" (E741). Improves readability.

```diff
-def _splitList(l, trainRatio:float = 0.8):
-    totalNum = len(l)
+def _splitList(items, trainRatio: float = 0.8):
+    totalNum = len(items)
     splitIdx = int(totalNum * trainRatio)
-    return l[:splitIdx], l[splitIdx :]
+    return items[:splitIdx], items[splitIdx:]
```
📜 Review details
📒 Files selected for processing (9)
- UNETR/BTCV/README.md (7 hunks)
- UNETR/BTCV/config.py (1 hunks)
- UNETR/BTCV/dataset/customDataset.py (1 hunks)
- UNETR/BTCV/main.py (4 hunks)
- UNETR/BTCV/networks/unetr.py (2 hunks)
- UNETR/BTCV/requirements.txt (1 hunks)
- UNETR/BTCV/test.py (3 hunks)
- UNETR/BTCV/trainer.py (1 hunks)
- UNETR/BTCV/utils/data_utils.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
UNETR/BTCV/test.py (3)
- UNETR/BTCV/networks/unetr.py (1): UNETR (22-230)
- UNETR/BTCV/trainer.py (1): dice (27-33)
- UNETR/BTCV/dataset/customDataset.py (2): getDatasetLoader (31-90), getPredictLoader (98-120)

UNETR/BTCV/main.py (5)
- UNETR/BTCV/networks/unetr.py (1): UNETR (22-230)
- UNETR/BTCV/optimizers/lr_scheduler.py (1): LinearWarmupCosineAnnealingLR (92-172)
- UNETR/BTCV/trainer.py (1): run_training (150-233)
- UNETR/BTCV/utils/data_utils.py (1): get_loader (72-168)
- UNETR/BTCV/dataset/customDataset.py (1): getDatasetLoader (31-90)
🪛 Ruff (0.12.2)
UNETR/BTCV/dataset/customDataset.py
92-92: Ambiguous variable name: l
(E741)
🪛 markdownlint-cli2 (0.17.2)
UNETR/BTCV/README.md
7-7: Heading levels should only increment by one level at a time
Expected: h2; Actual: h3
(MD001, heading-increment)
70-70: Bare URL used
(MD034, no-bare-urls)
157-157: Bare URL used
(MD034, no-bare-urls)
161-161: Bare URL used
(MD034, no-bare-urls)
🪛 LanguageTool
UNETR/BTCV/README.md
[grammar] ~36-~36: There might be a mistake here.
Context: ...rlapping patches of size (16, 16, 16)
. The position embedding is performed usin...
(QB_NEW_EN)
[grammar] ~37-~37: There might be a mistake here.
Context: ...d hyper-parameters as introduced in [2]. The decoder uses convolutional and resid...
(QB_NEW_EN)
[grammar] ~40-~40: There might be a mistake here.
Context: ...ommand can be used to initiate training using PyTorch native AMP package: ```bash py...
(QB_NEW_EN)
[style] ~59-~59: You have already used this phrasing in nearby sentences. Consider replacing it to add variety to your writing.
Context: ...ing command. To disable AMP, --noamp
needs to be added to the training command. If U...
(REP_NEED_TO_VB)
[grammar] ~61-~61: There might be a mistake here.
Context: ...ng the learning rate (i.e. --optim_lr
) according to the number of GPUs. For ins...
(QB_NEW_EN)
[grammar] ~66-~66: There might be a mistake here.
Context: ...kpoints and TorchScript models of UNETR using BTCV dataset. For using the pre-traine...
(QB_NEW_EN)
[uncategorized] ~72-~72: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...point in the following directory or use --pretrained_dir
to provide the address of where th...
(EN_WORD_COHERENCY)
[uncategorized] ~97-~97: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...model in the following directory or use --pretrained_dir
to provide the address of where th...
(EN_WORD_COHERENCY)
[grammar] ~125-~125: There might be a mistake here.
Context: ...ript model in the following directory or use --pretrained_dir
to provide the ad...
(QB_NEW_EN)
[uncategorized] ~126-~126: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...model in the following directory or use --pretrained_dir
to provide the address of where th...
(EN_WORD_COHERENCY)
[uncategorized] ~151-~151: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...nce time. If you would like to use the pretrained TorchScript model, `--saved_checkpoint=...
(EN_WORD_COHERENCY)
[grammar] ~155-~155: There might be a mistake here.
Context: ...or the task of multi-organ segmentation using BTCV dataset can be found in the follow...
(QB_NEW_EN)
🔇 Additional comments (5)
UNETR/BTCV/utils/data_utils.py (2)
78-79: Channel handling change looks good; confirm label shape assumptions. Switching to `EnsureChannelFirstd(..., channel_dim="no_channel")` is fine. Ensure downstream expects labels with shape [1, D, H, W].

Also applies to: 109-110
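A tiny sketch of what that transform does to an un-channeled volume; the array shapes are made up for illustration and assume MONAI is installed.

```python
import numpy as np
from monai.transforms import EnsureChannelFirstd

# A sample whose arrays have no channel axis
sample = {
    "image": np.zeros((96, 96, 64), dtype=np.float32),
    "label": np.zeros((96, 96, 64), dtype=np.uint8),
}

add_channel = EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel")
out = add_channel(sample)
print(out["label"].shape)  # (1, 96, 96, 64) -- the channel-first layout downstream code expects
```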
77-79: Remove the ITKReader override: nibabel is still declared in DAE/BTCV_Finetune/requirements.txt, so the default loader will handle .nii/.nii.gz correctly.

UNETR/BTCV/main.py (2)
123-124: Mixed batch formats between BTCV and custom loaders; ensure the val_epoch fix is applied. With args.btcv True, batches are dicts; with the custom loader, they are [images, labels]. This relies on trainer.val_epoch handling both. Please apply the trainer.py fix accordingly.
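To make the requirement concrete, here is a hypothetical unpacking helper in the spirit of that fix; the actual trainer.py change in this PR may be structured differently.

```python
def unpack_batch(batch):
    """Return (images, labels) whether the loader yields a MONAI-style dict or an [images, labels] pair."""
    if isinstance(batch, dict):      # BTCV json loader (--btcv): {"image": ..., "label": ...}
        return batch["image"], batch["label"]
    images, labels = batch           # custom NIfTI loader: [images, labels]
    return images, labels

# usage inside a validation loop:
#   images, labels = unpack_batch(batch)
#   images, labels = images.cuda(), labels.cuda()
```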
162-164: Post-processing aligns with MONAI metrics. AsDiscrete(to_onehot=out_channels) and (argmax=True, to_onehot=out_channels) are appropriate for DiceMetric with include_background=True.
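A minimal sketch of that post-processing/metric pairing, assuming 14 output channels (background + 13 BTCV organs); tensor sizes here are arbitrary.

```python
import torch
from monai.metrics import DiceMetric
from monai.transforms import AsDiscrete

out_channels = 14  # assumption: background + 13 organs
post_label = AsDiscrete(to_onehot=out_channels)
post_pred = AsDiscrete(argmax=True, to_onehot=out_channels)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

logits = torch.randn(out_channels, 64, 64, 64)            # one decollated model output (C, H, W, D)
label = torch.randint(0, out_channels, (1, 64, 64, 64))   # matching label (1, H, W, D)

# Argmax + one-hot the prediction, one-hot the label, then accumulate Dice.
dice_metric(y_pred=[post_pred(logits)], y=[post_label(label)])
print(dice_metric.aggregate().item())
```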
UNETR/BTCV/test.py (1)
101-103: torch.load keyword compatibility. weights_only was introduced in newer PyTorch. Your call uses weights_only=False, which can raise TypeError on older versions. Guard it for wider compatibility.

```diff
-    model_dict = torch.load(pretrained_pth, weights_only=False)
+    try:
+        model_dict = torch.load(pretrained_pth, weights_only=False)
+    except TypeError:
+        model_dict = torch.load(pretrained_pth)
```
- `dataset/customDataset.py`
  - Ensure the data types of image/label match the float/int required by CE/one-hot
  - Add transforms.EnsureTyped to the preTransform pipeline
- `networks/unetr.py`, `README.md`
  - Change the default value of pos_embed to `learnable`
- `requirements.txt`
  - Constrain the torch version range and add a nibabel version
- `test.py`
  - Refactor the code
- `trainer.py`
  - Handle batch unpacking from the loader
Actionable comments posted: 1
♻️ Duplicate comments (1)
UNETR/BTCV/requirements.txt (1)
1-7: Pins look sane; past concerns addressed. Good: MONAI 1.5.0, PyTorch constrained, NiBabel present. This addresses the prior "pin torch + add NiFTI reader" feedback.
🧹 Nitpick comments (14)
UNETR/BTCV/requirements.txt (1)
3-3: Prefer headless OpenCV for server/CI environments. Switch to opencv-python-headless to avoid unnecessary GUI deps and X11 issues.

```diff
-opencv_python
+opencv-python-headless
```

UNETR/BTCV/dataset/customDataset.py (4)
92-97: Rename the ambiguous variable and sort for a deterministic split. Avoid "l"; sort for reproducibility instead of relying on filesystem order.

```diff
-def _splitList(l, trainRatio:float = 0.8):
-    totalNum = len(l)
-    splitIdx = int(totalNum * trainRatio)
-
-    return l[:splitIdx], l[splitIdx :]
+def _splitList(items, trainRatio: float = 0.8):
+    items = sorted(items)
+    total = len(items)
+    split = int(total * trainRatio)
+    return items[:split], items[split:]
```

86-87: Enable pinned memory (and optional prefetch) for faster GPU transfer. Dataloaders feeding CUDA benefit from pin_memory; prefetch_factor helps when num_workers > 0.

```diff
-trainLoader = DataLoader(trainDataset,batch_size=args.batch_size,shuffle=True,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=True))
-valLoader = DataLoader(valDataset,batch_size=args.batch_size,shuffle=False,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=False))
+trainLoader = DataLoader(
+    trainDataset,
+    batch_size=args.batch_size,
+    shuffle=True,
+    num_workers=args.workers,
+    pin_memory=True,
+    prefetch_factor=2 if args.workers and args.workers > 0 else None,
+    collate_fn=_get_collate_fn(isTrain=True),
+)
+valLoader = DataLoader(
+    valDataset,
+    batch_size=args.batch_size,
+    shuffle=False,
+    num_workers=args.workers,
+    pin_memory=True,
+    prefetch_factor=2 if args.workers and args.workers > 0 else None,
+    collate_fn=_get_collate_fn(isTrain=False),
+)
```

48-57: Expose num_samples via args to avoid hard-coding. Let users tune crop density without editing code.

```diff
             transforms.RandCropByPosNegLabeld(
                 keys=["image", "label"],
                 label_key="label",
                 spatial_size=(args.roi_x, args.roi_y, args.roi_z),
                 pos=1,
                 neg=1,
-                num_samples=4,
+                num_samples=getattr(args, "num_samples", 4),
                 image_key="image",
                 image_threshold=0,
             ),
```

9-29: Consider using MONAI's list_data_collate to simplify collate_fn. list_data_collate flattens RandCropByPosNegLabeld outputs and stacks tensors. If you keep the custom collate, returning a tuple (images, labels) is more idiomatic than a list.

```diff
-    return [images.float(), labels.long()]
+    return images.float(), labels.long()
```

Alternative:

```python
from monai.data import list_data_collate
# ...
trainLoader = DataLoader(trainDataset, ..., collate_fn=list_data_collate)
valLoader = DataLoader(valDataset, ..., collate_fn=list_data_collate)
```

UNETR/BTCV/README.md (9)
8-8: Fix heading level (MD001). Use H2 after an H1.

```diff
-### Installing Dependencies
+## Installing Dependencies
```

28-29: Align the text with the new default pos_embed='learnable'. Below you still say "position embedding is performed using a perceptron layer." Update for consistency.

```diff
-The position embedding is performed using a perceptron layer. The ViT encoder follows standard hyper-parameters as introduced in [2].
+The position embedding uses learnable parameters by default. The ViT encoder follows standard hyper-parameters as introduced in [2].
```

70-75: Replace bare URLs with Markdown links and standardize "pretrained". Improves rendering and consistency.

```diff
-https://developer.download.nvidia.com/assets/Clara/monai/research/UNETR_model_best_acc.pth
+[UNETR_model_best_acc.pth](https://developer.download.nvidia.com/assets/Clara/monai/research/UNETR_model_best_acc.pth)
@@
-`./pretrained_models`
+`./pretrained_models`
```

95-101: Same: link formatting for the TorchScript model.

```diff
-https://developer.download.nvidia.com/assets/Clara/monai/research/UNETR_model_best_acc.pt
+[UNETR_model_best_acc.pt](https://developer.download.nvidia.com/assets/Clara/monai/research/UNETR_model_best_acc.pt)
@@
-`./pretrained_models`
+`./pretrained_models`
```

141-147: Nice addition of predict mode. Add path notes for custom data. Mention that images for prediction should be placed under PREDICT_DATA_ROOT (config.py) to match the new loader.

```diff
 python test.py --mode='predict' --infer_overlap=0.5 --pretrained_dir='./pretrained_models/' --saved_checkpoint=ckpt
+# Images are read from PREDICT_DATA_ROOT defined in UNETR/BTCV/config.py.
```

171-175: Typo: "Left Kideny" → "Left Kidney"; minor formatting. Also add spaces after numbers for readability.

```diff
-- Target: 13 abdominal organs including 1. Spleen 2. Right Kidney 3. Left Kideny 4.Gallbladder 5.Esophagus 6. Liver 7. Stomach 8.Aorta 9. IVC 10. Portal and Splenic Veins 11. Pancreas 12.Right adrenal gland 13.Left adrenal gland.
+- Target: 13 abdominal organs including 1. Spleen 2. Right Kidney 3. Left Kidney 4. Gallbladder 5. Esophagus 6. Liver 7. Stomach 8. Aorta 9. IVC 10. Portal and Splenic Veins 11. Pancreas 12. Right adrenal gland 13. Left adrenal gland.
```

155-162: Convert tutorial bare URLs to links.

```diff
-https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d.ipynb
+[UNETR BTCV tutorial (native)](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d.ipynb)
-https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d_lightning.ipynb
+[UNETR BTCV tutorial (Lightning)](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d_lightning.ipynb)
```

176-180: Convert the dataset JSON URL to a link.

```diff
-https://developer.download.nvidia.com/assets/Clara/monai/tutorials/swin_unetr_btcv_dataset_0.json
+[swin_unetr_btcv_dataset_0.json](https://developer.download.nvidia.com/assets/Clara/monai/tutorials/swin_unetr_btcv_dataset_0.json)
```

55-63: Add a short "Custom NIfTI data" subsection referencing config.py. Docs currently focus on BTCV + --data_dir. Add where to place NIFTI_DATA_ROOT/NIFTI_LABEL_ROOT and how to train without --btcv.

````diff
 Note that you need to provide the location of your dataset directory by using `--data_dir`. To initiate distributed multi-gpu training, `--distributed` needs to be added to the training command. To disable AMP, `--noamp` needs to be added to the training command. If UNETR is used in distributed multi-gpu training, we recommend increasing the learning rate (i.e. `--optim_lr`) according to the number of GPUs. For instance, `--optim_lr=4e-4` is recommended for training with 4 GPUs.
+
+#### Training on your own NIfTI data
+
+Set the following in `UNETR/BTCV/config.py`:
+
+- `NIFTI_DATA_ROOT`: folder containing images (`*.nii`/`*.nii.gz`)
+- `NIFTI_LABEL_ROOT`: folder containing corresponding labels (matching filenames)
+
+Then run:
+
+```bash
+python main.py
+```
+
+To train with the original BTCV json/dataloader instead, add `--btcv`.
````
📜 Review details
📒 Files selected for processing (6)
- UNETR/BTCV/README.md (7 hunks)
- UNETR/BTCV/dataset/customDataset.py (1 hunks)
- UNETR/BTCV/networks/unetr.py (3 hunks)
- UNETR/BTCV/requirements.txt (1 hunks)
- UNETR/BTCV/test.py (3 hunks)
- UNETR/BTCV/trainer.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- UNETR/BTCV/test.py
- UNETR/BTCV/trainer.py
- UNETR/BTCV/networks/unetr.py
🧰 Additional context used
🪛 markdownlint-cli2 (0.17.2)
UNETR/BTCV/README.md
7-7: Heading levels should only increment by one level at a time
Expected: h2; Actual: h3
(MD001, heading-increment)
70-70: Bare URL used
(MD034, no-bare-urls)
157-157: Bare URL used
(MD034, no-bare-urls)
161-161: Bare URL used
(MD034, no-bare-urls)
🪛 LanguageTool
UNETR/BTCV/README.md
[grammar] ~35-~35: There might be a mistake here.
Context: ...egmentation outputs. The network expects resampled input images with size `(96, 9...
(QB_NEW_EN)
[grammar] ~36-~36: There might be a mistake here.
Context: ...rlapping patches of size (16, 16, 16)
. The position embedding is performed usin...
(QB_NEW_EN)
[grammar] ~37-~37: There might be a mistake here.
Context: ...d hyper-parameters as introduced in [2]. The decoder uses convolutional and resid...
(QB_NEW_EN)
[grammar] ~40-~40: There might be a mistake here.
Context: ...ommand can be used to initiate training using PyTorch native AMP package: ```bash py...
(QB_NEW_EN)
[style] ~59-~59: You have already used this phrasing in nearby sentences. Consider replacing it to add variety to your writing.
Context: ...ing command. To disable AMP, --noamp
needs to be added to the training command. If U...
(REP_NEED_TO_VB)
[grammar] ~61-~61: There might be a mistake here.
Context: ...ng the learning rate (i.e. --optim_lr
) according to the number of GPUs. For ins...
(QB_NEW_EN)
[grammar] ~66-~66: There might be a mistake here.
Context: ...kpoints and TorchScript models of UNETR using BTCV dataset. For using the pre-traine...
(QB_NEW_EN)
[uncategorized] ~72-~72: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...point in the following directory or use --pretrained_dir
to provide the address of where th...
(EN_WORD_COHERENCY)
[uncategorized] ~97-~97: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...model in the following directory or use --pretrained_dir
to provide the address of where th...
(EN_WORD_COHERENCY)
[grammar] ~125-~125: There might be a mistake here.
Context: ...ript model in the following directory or use --pretrained_dir
to provide the ad...
(QB_NEW_EN)
[uncategorized] ~126-~126: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...model in the following directory or use --pretrained_dir
to provide the address of where th...
(EN_WORD_COHERENCY)
[uncategorized] ~151-~151: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...nce time. If you would like to use the pretrained TorchScript model, `--saved_checkpoint=...
(EN_WORD_COHERENCY)
[grammar] ~155-~155: There might be a mistake here.
Context: ...or the task of multi-organ segmentation using BTCV dataset can be found in the follow...
(QB_NEW_EN)
🪛 Ruff (0.12.2)
UNETR/BTCV/dataset/customDataset.py
92-92: Ambiguous variable name: l
(E741)
🔇 Additional comments (2)
UNETR/BTCV/requirements.txt (1)
6-7: Compatibility of dependencies verified. MONAI 1.5.0 officially supports PyTorch 2.3–2.6 (i.e. torch>=2.3,<2.7) and relies on the chosen PyTorch wheel for CUDA; PyTorch 2.3–2.6 is fully compatible with NumPy 2.x. Ensure your CI installs the appropriate torch+CUDA wheel per PyTorch's recommendations.

UNETR/BTCV/dataset/customDataset.py (1)
102-116: Good call adding EnsureTyped(track_meta=True) for predict. This preserves metadata for Invertd and safe .to(device) in test.py.
…ss image/label roots
Actionable comments posted: 1
♻️ Duplicate comments (3)
UNETR/BTCV/dataset/customDataset.py (3)
24-27: Good fix: preserve image float and label long dtypes in collate. This addresses the earlier dtype corruption and avoids unnecessary copies.
31-39: Robust image/label pairing via extension filtering + intersection looks solid. Nice guard against missing/extra files and mismatches.
120-121: Good addition: EnsureTyped(track_meta=True) for the predict pipeline. Enables Invertd and safe .to(device) without dropping metadata.
🧹 Nitpick comments (4)
UNETR/BTCV/dataset/customDataset.py (4)
31-38: Add explicit directory existence checks for clearer errors. os.listdir will raise a less-informative exception if the roots are missing. Fail fast with a friendly message.

Apply:

```diff
 def getDatasetLoader(args):
-    exts = (".nii", ".nii.gz")
+    exts = (".nii", ".nii.gz")
+    if not os.path.isdir(NIFTI_DATA_ROOT):
+        raise FileNotFoundError(f"NIFTI_DATA_ROOT does not exist: {NIFTI_DATA_ROOT}")
+    if not os.path.isdir(NIFTI_LABEL_ROOT):
+        raise FileNotFoundError(f"NIFTI_LABEL_ROOT does not exist: {NIFTI_LABEL_ROOT}")
     img_names = {f for f in os.listdir(NIFTI_DATA_ROOT) if f.endswith(exts) and os.path.isfile(os.path.join(NIFTI_DATA_ROOT, f))}
     lbl_names = {f for f in os.listdir(NIFTI_LABEL_ROOT) if f.endswith(exts) and os.path.isfile(os.path.join(NIFTI_LABEL_ROOT, f))}
```

92-93: Enable pinned memory and persistent workers for faster loading. Minor perf win on GPU and reduces worker respawn overhead.

```diff
-    trainLoader = DataLoader(trainDataset,batch_size=args.batch_size,shuffle=True,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=True))
-    valLoader = DataLoader(valDataset,batch_size=args.batch_size,shuffle=False,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=False))
+    trainLoader = DataLoader(
+        trainDataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.workers,
+        pin_memory=torch.cuda.is_available(),
+        persistent_workers=bool(args.workers),
+        collate_fn=_get_collate_fn(isTrain=True),
+    )
+    valLoader = DataLoader(
+        valDataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.workers,
+        pin_memory=torch.cuda.is_available(),
+        persistent_workers=bool(args.workers),
+        collate_fn=_get_collate_fn(isTrain=False),
+    )
```

98-103: Rename the ambiguous variable 'l' and guard split edge cases. Improves readability and satisfies lint (E741).

```diff
-def _splitList(l, trainRatio:float = 0.8):
-    totalNum = len(l)
-    splitIdx = int(totalNum * trainRatio)
-
-    return l[:splitIdx], l[splitIdx :]
+def _splitList(items, trainRatio: float = 0.8):
+    total = len(items)
+    splitIdx = int(total * trainRatio)
+    # ensure both splits non-empty when possible
+    if total > 1 and splitIdx == 0:
+        splitIdx = 1
+    if total > 1 and splitIdx == total:
+        splitIdx = total - 1
+    return items[:splitIdx], items[splitIdx:]
```

74-88: Optional: use EnsureTyped for val if you plan to invert transforms during validation. Only needed if the validation pipeline applies Invertd to recover the original spacing/origin for metrics/saving.

```diff
     val_transform = transforms.Compose(
         [
             transforms.LoadImaged(keys=["image", "label"]),
             transforms.EnsureChannelFirstd(keys=["image", "label"]),
             transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
             transforms.Spacingd(
                 keys=["image", "label"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear", "nearest")
             ),
             transforms.ScaleIntensityRanged(
                 keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
             ),
             transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
-            transforms.ToTensord(keys=["image", "label"]),
+            transforms.EnsureTyped(keys=["image", "label"], track_meta=True),
         ]
     )
```
📜 Review details
📒 Files selected for processing (1)
- UNETR/BTCV/dataset/customDataset.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
UNETR/BTCV/dataset/customDataset.py
37-37: Avoid specifying long messages outside the exception class
(TRY003)
98-98: Ambiguous variable name: l
(E741)
🔇 Additional comments (1)
UNETR/BTCV/dataset/customDataset.py (1)
54-63: Confirm the effective batch size (batch_size × num_samples) fits in memory. With RandCropByPosNegLabeld(num_samples=4) and the custom collate, each DataLoader batch becomes batch_size*4 patches. Ensure the trainer/optimizer expect this and your GPU can handle it, or reduce num_samples or batch_size.
Also applies to: 92-93
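To spell out the arithmetic with hypothetical numbers:

```python
# Hypothetical values; the PR hard-codes num_samples=4 and batch_size comes from argparse.
batch_size = 2
num_samples = 4  # crops returned per volume by RandCropByPosNegLabeld

patches_per_step = batch_size * num_samples
print(patches_per_step)  # 8 ROI patches pass through the model in one training step
```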
```python
dataName = [d for d in os.listdir(PREDICT_DATA_ROOT)]
dataDicts = [{"image": f"{os.path.join(PREDICT_DATA_ROOT, d)}" } for d in dataName]
```
🛠️ Refactor suggestion
Filter predict files to NIfTI and validate presence.
Unfiltered os.listdir may include non-NIfTI files (.DS_Store, JSON, etc.) and will break LoadImaged.
```diff
-def getPredictLoader(args):
-    dataName = [d for d in os.listdir(PREDICT_DATA_ROOT)]
-    dataDicts = [{"image": f"{os.path.join(PREDICT_DATA_ROOT, d)}" } for d in dataName]
+def getPredictLoader(args):
+    exts = (".nii", ".nii.gz")
+    if not os.path.isdir(PREDICT_DATA_ROOT):
+        raise FileNotFoundError(f"PREDICT_DATA_ROOT does not exist: {PREDICT_DATA_ROOT}")
+    files = sorted(
+        f for f in os.listdir(PREDICT_DATA_ROOT)
+        if f.endswith(exts) and os.path.isfile(os.path.join(PREDICT_DATA_ROOT, f))
+    )
+    if not files:
+        raise FileNotFoundError(f"No NIfTI files (.nii, .nii.gz) found in {PREDICT_DATA_ROOT}")
+    dataDicts = [{"image": os.path.join(PREDICT_DATA_ROOT, f)} for f in files]
```
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
def getPredictLoader(args):
    exts = (".nii", ".nii.gz")
    if not os.path.isdir(PREDICT_DATA_ROOT):
        raise FileNotFoundError(f"PREDICT_DATA_ROOT does not exist: {PREDICT_DATA_ROOT}")
    files = sorted(
        f for f in os.listdir(PREDICT_DATA_ROOT)
        if f.endswith(exts) and os.path.isfile(os.path.join(PREDICT_DATA_ROOT, f))
    )
    if not files:
        raise FileNotFoundError(f"No NIfTI files (.nii, .nii.gz) found in {PREDICT_DATA_ROOT}")
    dataDicts = [{"image": os.path.join(PREDICT_DATA_ROOT, f)} for f in files]
```
🤖 Prompt for AI Agents
In UNETR/BTCV/dataset/customDataset.py around lines 105 to 107, the code builds
dataName from os.listdir which can include non-NIfTI files and will break
LoadImaged; change it to only include files with NIfTI extensions (e.g., .nii,
.nii.gz) using a filter (or glob) and ensure each entry is a regular file, build
dataDicts from those paths, and add a validation step that logs/raises an error
if no valid NIfTI files are found.
Just put the images in PREDICT_DATA_ROOT defined in config.py and use `python test.py --mode=predict` to write the predicted labels to the ./output folder. Put the images and labels in NIFTI_DATA_ROOT and NIFTI_LABEL_ROOT defined in config.py and use `python main.py` to train with your own data; `python main.py --btcv` trains with the original BTCV dataset.
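For reference, those directory constants look like this in config.py; the defaults below are taken from the review diff earlier on this page, so point them at wherever your data actually lives.

```python
# UNETR/BTCV/config.py
NIFTI_DATA_ROOT = 'data/images'    # nifti image directory
NIFTI_LABEL_ROOT = 'data/labels'   # nifti label directory
PREDICT_DATA_ROOT = 'data/predict' # predict image directory
```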
Summary by CodeRabbit

- New Features
- Documentation
- Refactor
- Chores
- Bug Fix