@tylin7111095022 tylin7111095022 commented Sep 2, 2025

  1. Add a label-prediction mode to test.py.

Put the images in the directory set by PREDICT_DATA_ROOT in config.py and run python test.py --mode=predict; the predicted labels are written to the ./output folder.

  2. Add a custom data loader in dataset/customDataset.py for training on your own data.

Put the images and labels in the directories set by NIFTI_DATA_ROOT and NIFTI_LABEL_ROOT in config.py and run python main.py to train on your own data.

  3. Run python main.py --btcv to train on the original BTCV dataset.

Summary by CodeRabbit

  • New Features

    • BTCV dataset support with end-to-end training/validation/prediction pipelines, configurable dirs, and sliding-window inference; per-class and mean Dice reporting; automatic output saving.
  • Documentation

    • README updated with clearer CLI examples, pretrained/finetune workflows, dataset details, tutorial links, and citation formatting.
  • Refactor

    • Position-embedding default changed to learnable; input/ROI defaults and training cadence adjusted (smaller ROIs, unit spacing, more frequent validation).
  • Chores

    • Dependency list modernized to recent imaging and ML package versions.
  • Bug Fix

    • Batch handling made compatible with tuple/list batch formats.


coderabbitai bot commented Sep 2, 2025

Walkthrough

Adds BTCV dataset support with MONAI-based training/validation and prediction loaders, introduces new dataset/config constants, changes UNETR positional-embedding defaults and validation, updates training/inference flows (sliding-window inference, prediction saving), adjusts transforms/utilities and dependencies, and refreshes BTCV README and CLI examples.

Changes

| Cohort | File(s) | Summary |
| --- | --- | --- |
| Documentation refresh | UNETR/BTCV/README.md | Large README edits: standardized code blocks/CLI examples, expanded BTCV dataset description and dataset JSON, pretrained/TorchScript flags and paths, citation/BibTeX and link updates. |
| Config constants | UNETR/BTCV/config.py | Added directory constants: NIFTI_DATA_ROOT, NIFTI_LABEL_ROOT, PREDICT_DATA_ROOT. |
| Dataset & loaders (MONAI) | UNETR/BTCV/dataset/customDataset.py | New MONAI pipelines and loaders: getDatasetLoader(args) returning train/val loaders with custom collate handling (flattening RandCrop outputs) and getPredictLoader(args) returning prediction loader and preTransform. |
| Training entrypoint changes | UNETR/BTCV/main.py | Added --btcv flag and getDatasetLoader selection, changed defaults (space -> 1.0, roi -> 64), --pos_embed default → learnable, reduced --val_every and --max_epochs, enabled --save_checkpoint default, and adjusted post-processing conversions. |
| Model pos-embed options | UNETR/BTCV/networks/unetr.py | Default pos_embed switched to learnable; "perceptron" aliased to "learnable"; accepted options now ['sincos','learnable','none']; ViT call uses pos_embed_type. |
| Inference & testing overhaul | UNETR/BTCV/test.py | New --mode (validation/predict), uses new loaders, added inference(inputs, model, args) using sliding-window inference + softmax, validation computes per-class and mean Dice, predict path saves outputs via Invertd/AsDiscreted/SaveImaged; model loading adjusted. |
| Trainer batch unpacking | UNETR/BTCV/trainer.py | train_epoch/val_epoch accept sequence (list/tuple) batches as (data, target) and fallback to dict keys when not a sequence. |
| Utilities & transforms | UNETR/BTCV/utils/data_utils.py | Replaced AddChanneld with EnsureChannelFirstd(channel_dim="no_channel"); added WORKROOT (Path) and JPG_EXT constants; added Path and SimpleITK imports. |
| Dependencies | UNETR/BTCV/requirements.txt | Updated dependency pins: monai → 1.5.0, torch relaxed to >=2.3,<2.7, added numpy/opencv_python/simpleitk/tensorboardx, updated nibabel/tensorboardx versions, removed several older pinned packages. |
| Tests / scripts | UNETR/BTCV/test.py | Added inference function and reworked validation/prediction flows to use MONAI sliding-window inference and MONAI post-transforms. |

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant Main as main.py
  participant Loader as get_loader / getDatasetLoader
  participant Trainer as trainer.py
  participant Model as UNETR

  User->>Main: python main.py [--btcv ...]
  alt args.btcv == True
    Main->>Loader: getDatasetLoader(args)
  else
    Main->>Loader: get_loader(args)
  end
  Loader-->>Main: train_loader, val_loader
  Main->>Model: initialize UNETR (pos_embed='learnable'|'sincos'|'none')
  Main->>Trainer: train(train_loader, val_loader, Model, args)
  note over Trainer,Model: Validation runs more often (val_every lowered)
sequenceDiagram
  autonumber
  actor User
  participant Test as test.py
  participant Dset as getDatasetLoader/getPredictLoader
  participant Model as UNETR
  participant SWI as sliding_window_inference
  participant IO as SaveImaged

  User->>Test: python test.py --mode validation|predict
  alt mode == validation
    Test->>Dset: getDatasetLoader(args)
    Dset-->>Test: val_loader
    Test->>Model: load checkpoint
    loop batches
      Test->>SWI: inference(inputs, model, args)
      SWI-->>Test: logits
      Test->>Test: softmax -> argmax -> labels -> Dice
    end
    Test-->>User: Mean Dice results
  else mode == predict
    Test->>Dset: getPredictLoader(args)
    Dset-->>Test: pred_loader, preTransform
    Test->>Model: load checkpoint
    loop items
      Test->>SWI: inference(image, model, args)
      SWI-->>Test: logits
      Test->>IO: Invertd -> AsDiscreted -> SaveImaged
    end
    Test-->>User: Saved segmentations
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I twitch my ears at datasets new,
With learnable stars in embedding’s view.
I hop through windows, sliding light,
Saving masks by moonlit night.
Dice bells ring—thump-thump—I cheer. 🐇✨


@tylin7111095022 tylin7111095022 changed the title from "Add predict label function and custom dataloader which can train with own data ." to "(UNETR) : Add predict label function and custom dataloader which can train with own data ." on Sep 2, 2025

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
UNETR/BTCV/networks/unetr.py (1)

88-99: Guard ViT instantiation on pos_embed_type
MONAI renamed the ViT kwarg from pos_embed (available in MONAI 1.0.1) to pos_embed_type (introduced in MONAI 1.3.0) (docs.monai.io). Use introspection on inspect.signature(ViT).parameters to pass pos_embed_type when available, otherwise fall back to pos_embed.
Apply:

-        self.vit = ViT(
-            in_channels=in_channels,
-            img_size=img_size,
-            patch_size=self.patch_size,
-            hidden_size=hidden_size,
-            mlp_dim=mlp_dim,
-            num_layers=self.num_layers,
-            num_heads=num_heads,
-            pos_embed_type=pos_embed,
-            classification=self.classification,
-            dropout_rate=dropout_rate,
-        )
+        import inspect
+        vit_args = dict(
+            in_channels=in_channels,
+            img_size=img_size,
+            patch_size=self.patch_size,
+            hidden_size=hidden_size,
+            mlp_dim=mlp_dim,
+            num_layers=self.num_layers,
+            num_heads=num_heads,
+            classification=self.classification,
+            dropout_rate=dropout_rate,
+        )
+        if "pos_embed_type" in inspect.signature(ViT).parameters:
+            vit_args["pos_embed_type"] = pos_embed
+        else:
+            vit_args["pos_embed"] = pos_embed
+        self.vit = ViT(**vit_args)
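
The signature check suggested above can be tried in isolation. The following is a minimal sketch, assuming two hypothetical stand-in classes (OldViT and NewViT) in place of the real MONAI ViT, and a hypothetical helper build_vit; it only shows how the keyword is chosen, not how MONAI behaves.

```python
import inspect


class OldViT:
    """Hypothetical stand-in for a ViT that only accepts `pos_embed`."""

    def __init__(self, hidden_size, pos_embed="learnable"):
        self.pos = pos_embed


class NewViT:
    """Hypothetical stand-in for a ViT that accepts `pos_embed_type`."""

    def __init__(self, hidden_size, pos_embed_type="learnable"):
        self.pos = pos_embed_type


def build_vit(vit_cls, pos_embed, **common_kwargs):
    """Pass the position-embedding option under whichever name vit_cls accepts."""
    params = inspect.signature(vit_cls).parameters
    key = "pos_embed_type" if "pos_embed_type" in params else "pos_embed"
    return vit_cls(**common_kwargs, **{key: pos_embed})


old = build_vit(OldViT, "sincos", hidden_size=768)
new = build_vit(NewViT, "sincos", hidden_size=768)
print(old.pos, new.pos)
```

The check only selects the keyword name; it does not translate option values between versions, so the value passed still has to be one the installed version accepts.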
UNETR/BTCV/README.md (1)

19-33: Fix UNETR init snippet: pos_embed value outdated.

The code now supports ['sincos','learnable','none']; “perceptron” will raise.

Apply:

-    pos_embed='perceptron',
+    pos_embed='learnable',
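
The walkthrough describes "perceptron" as an alias for "learnable", while this comment expects it to raise. One way to make the behavior explicit is a small normalization step, sketched below. The names ALLOWED_POS_EMBED, LEGACY_ALIASES and normalize_pos_embed are hypothetical and are not taken from unetr.py.

```python
ALLOWED_POS_EMBED = ("sincos", "learnable", "none")
# Assumption: legacy names that should keep working map to a current option.
LEGACY_ALIASES = {"perceptron": "learnable"}


def normalize_pos_embed(value):
    """Map legacy names to current ones and reject anything unknown."""
    value = LEGACY_ALIASES.get(value, value)
    if value not in ALLOWED_POS_EMBED:
        raise ValueError(
            f"pos_embed must be one of {ALLOWED_POS_EMBED}, got {value!r}"
        )
    return value


print(normalize_pos_embed("perceptron"))
```

With such a step in place, README snippets that still pass 'perceptron' would keep working, and removing the alias entry would turn them into a clear error instead.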
🧹 Nitpick comments (15)
UNETR/BTCV/config.py (1)

1-3: Avoid hard-coded data roots; make them configurable via env vars.

This improves portability and lets users point to their own data without editing code.

Apply this diff (and add the import shown below) to allow overrides:

-NIFTI_DATA_ROOT = 'data/images' # nifti image directory
-NIFTI_LABEL_ROOT = 'data/labels' # nifti label directory
-PREDICT_DATA_ROOT = 'data/predict' # predict image directory
+NIFTI_DATA_ROOT = os.getenv("NIFTI_DATA_ROOT", "data/images")  # nifti image directory
+NIFTI_LABEL_ROOT = os.getenv("NIFTI_LABEL_ROOT", "data/labels")  # nifti label directory
+PREDICT_DATA_ROOT = os.getenv("PREDICT_DATA_ROOT", "data/predict")  # predict image directory

Outside the selected lines, add:

import os
UNETR/BTCV/requirements.txt (1)

5-5: Normalize package name casing.

The import is import tensorboardX but requirement lists tensorboardx. Pip is case-insensitive, but for consistency consider tensorboardX==2.6.4.

UNETR/BTCV/utils/data_utils.py (1)

16-17: Remove unused imports and constants.

Path, sitk, WORKROOT, and JPG_EXT aren’t used here.

Apply:

-from pathlib import Path
-import SimpleITK as sitk
+ # (removed unused imports)

-WORKROOT = Path(__file__).parent.parent
-JPG_EXT = '.jpg'
+ # (removed unused constants)

Also applies to: 22-23

UNETR/BTCV/networks/unetr.py (1)

61-66: Update examples to match new default/allowed values.

Replace pos_embed='perceptron' with 'learnable' and drop 'conv'.

Apply:

-            # for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm
-            >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance')
+            # for 4-channel input 3-channel output with patch size of (128,128,128), learnable position embedding and instance norm
+            >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='learnable', norm_name='instance')
UNETR/BTCV/README.md (6)

42-53: Make CLI block copy-pasteable.

Use line continuations for multi-line shell commands.

Apply:

-```bash
-python main.py
---feature_size=32
---batch_size=1
---logdir=unetr_test
---fold=0
---optim_lr=1e-4
---lrschedule=warmup_cosine
---infer_overlap=0.5
---save_checkpoint
---data_dir=/dataset/dataset0/
-```
+```bash
+python main.py \
+  --feature_size=32 \
+  --batch_size=1 \
+  --logdir=unetr_test \
+  --fold=0 \
+  --optim_lr=1e-4 \
+  --lrschedule=warmup_cosine \
+  --infer_overlap=0.5 \
+  --save_checkpoint \
+  --data_dir=/dataset/dataset0/
+```

11-13: Use correct code fences.

These blocks are Python, not bash.

Apply:

-```bash
+```python

Also applies to: 19-33


130-147: Clarify predict mode data location.

Mention PREDICT_DATA_ROOT (config or env var) so users know where to place images.

Proposed addition after the predict command: “Place images under PREDICT_DATA_ROOT (default: data/predict) or set the env var PREDICT_DATA_ROOT.”


7-7: Fix heading level increment (MD001).

“Installing Dependencies” should be H2 to follow H1.

Apply:

-### Installing Dependencies
+## Installing Dependencies

171-171: Typo: “Left Kideny” → “Left Kidney”.

-- Target: 13 abdominal organs including 1. Spleen 2. Right Kidney 3. Left Kideny 4.Gallbladder ...
+- Target: 13 abdominal organs including 1. Spleen 2. Right Kidney 3. Left Kidney 4. Gallbladder ...

70-70: Avoid bare URLs; make them clickable text.

Improves readability and satisfies MD034.

Example:

-https://developer.download.nvidia.com/assets/Clara/monai/research/UNETR_model_best_acc.pth
+[Pretrained UNETR checkpoint (.pth)](https://developer.download.nvidia.com/assets/Clara/monai/research/UNETR_model_best_acc.pth)

Apply similarly to the TorchScript link and tutorial links.

Also applies to: 97-97, 157-157, 161-161

UNETR/BTCV/main.py (1)

46-46: Flag default True on --save_checkpoint; cannot disable.

Using action="store_true", default=True prevents users from turning it off. Prefer BooleanOptionalAction or an explicit --no-save-checkpoint.

-parser.add_argument("--save_checkpoint", action="store_true", default=True, help="save checkpoint during training")
+parser.add_argument(
+    "--save_checkpoint",
+    action=argparse.BooleanOptionalAction,
+    default=True,
+    help="save checkpoint during training",
+)
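
A quick way to see the effect of BooleanOptionalAction (available in Python 3.9 and later) is to parse explicit argument lists. This is a standalone sketch rather than the parser from main.py. Note that argparse derives the negative flag from the option string, so it is spelled --no-save_checkpoint here.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--save_checkpoint",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="save checkpoint during training",
)

# No flag given: the default applies.
default_args = parser.parse_args([])
# The generated negative flag switches the option off.
disabled_args = parser.parse_args(["--no-save_checkpoint"])

print(default_args.save_checkpoint, disabled_args.save_checkpoint)
```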
UNETR/BTCV/test.py (1)

109-121: Validation loop only measures first item if batch_size>1.

Indexing [0] ignores the rest of the batch. Iterate per-item to support batch>1.

-            for batch, label in loader:
-                val_inputs, val_labels = (batch.cuda(), label.cuda())
-                val_outputs = inference(val_inputs, model, args)
-                val_labels = val_labels.cpu().numpy()[:, 0, :, :, :]
-                dice_list_sub = []
-                for i in range(1, args.out_channels):
-                    every_Dice = dice(val_outputs[0] == i, val_labels[0] == i)
-                    dice_list_sub.append(every_Dice)
-                mean_dice = np.mean(dice_list_sub)
+            for batch, label in loader:
+                val_inputs, val_labels = (batch.cuda(), label.cuda())
+                val_outputs = inference(val_inputs, model, args)  # (B, H, W, D)
+                val_labels = val_labels.cpu().numpy()[:, 0, ...]  # (B, H, W, D)
+                batch_dice = []
+                for b in range(val_outputs.shape[0]):
+                    dice_list_sub = []
+                    for i in range(1, args.out_channels):
+                        dice_list_sub.append(dice(val_outputs[b] == i, val_labels[b] == i))
+                    batch_dice.append(np.mean(dice_list_sub))
+                mean_dice = float(np.mean(batch_dice))
                 print("Mean Dice: {}".format(mean_dice))
                 dice_list_case.append(mean_dice)
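
The per-item averaging proposed above can be illustrated without NumPy or a GPU. The sketch below uses flat Python lists as stand-ins for label volumes; the dice function here is a simplified assumption (2|A∩B| / (|A|+|B|), returning 0.0 when both masks are empty) and is not the implementation in trainer.py.

```python
def dice(pred_mask, true_mask):
    """Dice overlap of two equal-length boolean sequences."""
    intersection = sum(1 for p, t in zip(pred_mask, true_mask) if p and t)
    total = sum(pred_mask) + sum(true_mask)
    return 2.0 * intersection / total if total else 0.0


def mean_dice_per_item(pred_batch, label_batch, num_classes):
    """Return one mean foreground Dice per item, so no item in the batch is skipped."""
    scores = []
    for pred, label in zip(pred_batch, label_batch):
        per_class = [
            dice([p == c for p in pred], [t == c for t in label])
            for c in range(1, num_classes)  # class 0 is background
        ]
        scores.append(sum(per_class) / len(per_class))
    return scores


# Two "volumes" of four voxels each, three classes (0 = background).
preds = [[0, 1, 1, 2], [0, 1, 2, 2]]
labels = [[0, 1, 1, 2], [0, 2, 2, 2]]
scores = mean_dice_per_item(preds, labels, num_classes=3)
print(scores)
```

Indexing only the first item would report a perfect score for this batch, while the per-item loop also reflects the weaker second prediction.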
UNETR/BTCV/dataset/customDataset.py (3)

32-34: Deterministic pairing and basic filtering.

Sort file lists and optionally filter valid extensions to keep image/label alignment stable across runs.

-    dataName = [d for d in os.listdir(NIFTI_LABEL_ROOT)]
+    dataName = sorted(d for d in os.listdir(NIFTI_LABEL_ROOT) if d.endswith((".nii", ".nii.gz")))
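
The same idea can be extended to check that every label has a matching image. The following sketch assumes images and labels share file names, which the PR does not state outright; pair_images_and_labels and NIFTI_SUFFIXES are hypothetical names, and the function takes plain name lists (for example from os.listdir) so it can be tried without any files on disk.

```python
NIFTI_SUFFIXES = (".nii", ".nii.gz")


def pair_images_and_labels(image_names, label_names):
    """Pair files by identical name, in sorted order, skipping non-NIfTI entries."""
    labels = {name for name in label_names if name.endswith(NIFTI_SUFFIXES)}
    common = sorted(name for name in image_names if name in labels)
    return [{"image": name, "label": name} for name in common]


images = ["b.nii.gz", "a.nii.gz", "notes.txt"]
labels = ["a.nii.gz", "b.nii.gz", ".DS_Store"]
pairs = pair_images_and_labels(images, labels)
print(pairs)
```

Because the result is sorted, a later ratio-based train/validation split selects the same files on every run, regardless of the order the filesystem returns them in.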

86-88: DataLoader perf knobs.

Consider pin_memory and persistent_workers for faster host→GPU transfer.

-    trainLoader = DataLoader(trainDataset,batch_size=args.batch_size,shuffle=True,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=True))
-    valLoader = DataLoader(valDataset,batch_size=args.batch_size,shuffle=False,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=False))
+    trainLoader = DataLoader(
+        trainDataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.workers,
+        pin_memory=True,
+        persistent_workers=args.workers > 0,
+        collate_fn=_get_collate_fn(isTrain=True),
+    )
+    valLoader = DataLoader(
+        valDataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.workers,
+        pin_memory=True,
+        persistent_workers=args.workers > 0,
+        collate_fn=_get_collate_fn(isTrain=False),
+    )
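
PyTorch's DataLoader rejects persistent_workers=True, and any non-None prefetch_factor, when num_workers is 0, so these options are safest when set conditionally. A minimal sketch of a helper that builds the keyword arguments follows; loader_kwargs is a hypothetical name and the function does not import torch, it only prepares a dict.

```python
def loader_kwargs(num_workers, use_cuda=True):
    """Build DataLoader keyword arguments that stay valid for any worker count."""
    kwargs = {"num_workers": num_workers, "pin_memory": use_cuda}
    if num_workers > 0:
        # Both options only apply when worker processes exist.
        kwargs["persistent_workers"] = True
        kwargs["prefetch_factor"] = 2
    return kwargs


print(loader_kwargs(0))
print(loader_kwargs(4))
```

The result could then be passed as DataLoader(dataset, batch_size=..., **loader_kwargs(args.workers)), which keeps single-process debugging runs working.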

92-97: Rename ambiguous parameter l.

Avoid single-letter “l” (E741). Improves readability.

-def _splitList(l, trainRatio:float = 0.8):
-    totalNum = len(l)
+def _splitList(items, trainRatio: float = 0.8):
+    totalNum = len(items)
     splitIdx = int(totalNum * trainRatio)
 
-    return l[:splitIdx], l[splitIdx :]
+    return items[:splitIdx], items[splitIdx:]
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 21ed8e5 and f0b3dc4.

📒 Files selected for processing (9)
  • UNETR/BTCV/README.md (7 hunks)
  • UNETR/BTCV/config.py (1 hunks)
  • UNETR/BTCV/dataset/customDataset.py (1 hunks)
  • UNETR/BTCV/main.py (4 hunks)
  • UNETR/BTCV/networks/unetr.py (2 hunks)
  • UNETR/BTCV/requirements.txt (1 hunks)
  • UNETR/BTCV/test.py (3 hunks)
  • UNETR/BTCV/trainer.py (1 hunks)
  • UNETR/BTCV/utils/data_utils.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
UNETR/BTCV/test.py (3)
UNETR/BTCV/networks/unetr.py (1)
  • UNETR (22-230)
UNETR/BTCV/trainer.py (1)
  • dice (27-33)
UNETR/BTCV/dataset/customDataset.py (2)
  • getDatasetLoader (31-90)
  • getPredictLoader (98-120)
UNETR/BTCV/main.py (5)
UNETR/BTCV/networks/unetr.py (1)
  • UNETR (22-230)
UNETR/BTCV/optimizers/lr_scheduler.py (1)
  • LinearWarmupCosineAnnealingLR (92-172)
UNETR/BTCV/trainer.py (1)
  • run_training (150-233)
UNETR/BTCV/utils/data_utils.py (1)
  • get_loader (72-168)
UNETR/BTCV/dataset/customDataset.py (1)
  • getDatasetLoader (31-90)
🪛 Ruff (0.12.2)
UNETR/BTCV/dataset/customDataset.py

92-92: Ambiguous variable name: l

(E741)

🪛 markdownlint-cli2 (0.17.2)
UNETR/BTCV/README.md

7-7: Heading levels should only increment by one level at a time
Expected: h2; Actual: h3

(MD001, heading-increment)


70-70: Bare URL used

(MD034, no-bare-urls)


157-157: Bare URL used

(MD034, no-bare-urls)


161-161: Bare URL used

(MD034, no-bare-urls)

🪛 LanguageTool
UNETR/BTCV/README.md

[grammar] ~36-~36: There might be a mistake here.
Context: ...rlapping patches of size (16, 16, 16). The position embedding is performed usin...

(QB_NEW_EN)


[grammar] ~37-~37: There might be a mistake here.
Context: ...d hyper-parameters as introduced in [2]. The decoder uses convolutional and resid...

(QB_NEW_EN)


[grammar] ~40-~40: There might be a mistake here.
Context: ...ommand can be used to initiate training using PyTorch native AMP package: ```bash py...

(QB_NEW_EN)


[style] ~59-~59: You have already used this phrasing in nearby sentences. Consider replacing it to add variety to your writing.
Context: ...ing command. To disable AMP, --noamp needs to be added to the training command. If U...

(REP_NEED_TO_VB)


[grammar] ~61-~61: There might be a mistake here.
Context: ...ng the learning rate (i.e. --optim_lr) according to the number of GPUs. For ins...

(QB_NEW_EN)


[grammar] ~66-~66: There might be a mistake here.
Context: ...kpoints and TorchScript models of UNETR using BTCV dataset. For using the pre-traine...

(QB_NEW_EN)


[uncategorized] ~72-~72: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...point in the following directory or use --pretrained_dir to provide the address of where th...

(EN_WORD_COHERENCY)


[uncategorized] ~97-~97: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...model in the following directory or use --pretrained_dir to provide the address of where th...

(EN_WORD_COHERENCY)


[grammar] ~125-~125: There might be a mistake here.
Context: ...ript model in the following directory or use --pretrained_dir to provide the ad...

(QB_NEW_EN)


[uncategorized] ~126-~126: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...model in the following directory or use --pretrained_dir to provide the address of where th...

(EN_WORD_COHERENCY)


[uncategorized] ~151-~151: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...nce time. If you would like to use the pretrained TorchScript model, `--saved_checkpoint=...

(EN_WORD_COHERENCY)


[grammar] ~155-~155: There might be a mistake here.
Context: ...or the task of multi-organ segmentation using BTCV dataset can be found in the follow...

(QB_NEW_EN)

🔇 Additional comments (5)
UNETR/BTCV/utils/data_utils.py (2)

78-79: Channel handling change looks good; confirm label shape assumptions.

Switching to EnsureChannelFirstd(..., channel_dim="no_channel") is fine. Ensure downstream expects labels with shape [1, D, H, W].

Also applies to: 109-110


77-79: Remove ITKReader override: nibabel is still declared in DAE/BTCV_Finetune/requirements.txt, so the default loader will handle .nii/.nii.gz correctly.

UNETR/BTCV/main.py (2)

123-124: Mixed batch formats between BTCV and custom loaders — ensure val_epoch fix is applied.

With args.btcv True, batches are dicts; with custom, they’re [images, labels]. This relies on trainer.val_epoch handling both. Please apply the trainer.py fix accordingly.
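
The handling both loaders need can be captured in one small helper, sketched below. unpack_batch is a hypothetical name, and the dict keys "image" and "label" are assumed from the transform keys used in customDataset.py; strings stand in for tensors so the sketch runs without torch.

```python
def unpack_batch(batch):
    """Return (data, target) from either a sequence batch or a dict batch."""
    if isinstance(batch, (list, tuple)):
        data, target = batch
    else:
        data, target = batch["image"], batch["label"]
    return data, target


# Custom loader style: [images, labels]
seq_result = unpack_batch(["images", "labels"])
# BTCV loader style: dict keyed by "image" and "label"
dict_result = unpack_batch({"image": "images", "label": "labels"})
print(seq_result, dict_result)
```

Calling the same helper from both train_epoch and val_epoch would keep the two code paths from drifting apart.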


162-164: Post-processing aligns with MONAI metrics.

AsDiscrete(to_onehot=out_channels) and (argmax=True, to_onehot=out_channels) are appropriate for DiceMetric with include_background=True.

UNETR/BTCV/test.py (1)

101-103: torch.load keyword compatibility.

weights_only was added to torch.load in PyTorch 1.13. Your call passes weights_only=False, which raises TypeError on older versions. Guard it for wider compatibility.

-        model_dict = torch.load(pretrained_pth, weights_only=False)
+        try:
+            model_dict = torch.load(pretrained_pth, weights_only=False)
+        except TypeError:
+            model_dict = torch.load(pretrained_pth)

tylin7111095022 and others added 3 commits September 3, 2025 09:39
- `dataset/customDataset.py`
  - Ensure image/label dtypes match the float/int types required by
    CE/one-hot
  - Add transforms.EnsureTyped to the preTransform pipeline
- `networks/unetr.py`, `README.md`
  - Change the default value of pos_embed to `learnable`
- `requirements.txt`
  - Constrain the torch version range and add a nibabel version
- `test.py`
  - Refactor the code
- `trainer.py`
  - Unpack batches according to the loader's batch format

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
UNETR/BTCV/requirements.txt (1)

1-7: Pins look sane; past concerns addressed.

Good: MONAI 1.5.0, PyTorch constrained, NiBabel present. This addresses the prior “pin torch + add NIfTI reader” feedback.

🧹 Nitpick comments (14)
UNETR/BTCV/requirements.txt (1)

3-3: Prefer headless OpenCV for server/CI environments.

Switch to opencv-python-headless to avoid unnecessary GUI deps and X11 issues.

-opencv_python
+opencv-python-headless
UNETR/BTCV/dataset/customDataset.py (4)

92-97: Rename ambiguous variable and sort for deterministic split.

Avoid “l”; sort for reproducibility instead of relying on filesystem order.

-def _splitList(l, trainRatio:float = 0.8):
-    totalNum = len(l)
-    splitIdx = int(totalNum * trainRatio)
-
-    return l[:splitIdx], l[splitIdx :]
+def _splitList(items, trainRatio: float = 0.8):
+    items = sorted(items)
+    total = len(items)
+    split = int(total * trainRatio)
+    return items[:split], items[split:]

86-87: Enable pinned memory (and optional prefetch) for faster GPU transfer.

Dataloaders feeding CUDA benefit from pin_memory; prefetch_factor helps when num_workers > 0.

-trainLoader = DataLoader(trainDataset,batch_size=args.batch_size,shuffle=True,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=True))
-valLoader = DataLoader(valDataset,batch_size=args.batch_size,shuffle=False,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=False))
+trainLoader = DataLoader(
+    trainDataset,
+    batch_size=args.batch_size,
+    shuffle=True,
+    num_workers=args.workers,
+    pin_memory=True,
+    prefetch_factor=2 if args.workers and args.workers > 0 else None,
+    collate_fn=_get_collate_fn(isTrain=True),
+)
+valLoader = DataLoader(
+    valDataset,
+    batch_size=args.batch_size,
+    shuffle=False,
+    num_workers=args.workers,
+    pin_memory=True,
+    prefetch_factor=2 if args.workers and args.workers > 0 else None,
+    collate_fn=_get_collate_fn(isTrain=False),
+)

48-57: Expose num_samples via args to avoid hard-coding.

Let users tune crop density without editing code.

-            transforms.RandCropByPosNegLabeld(
+            transforms.RandCropByPosNegLabeld(
                 keys=["image", "label"],
                 label_key="label",
                 spatial_size=(args.roi_x, args.roi_y, args.roi_z),
                 pos=1,
                 neg=1,
-                num_samples=4,
+                num_samples=getattr(args, "num_samples", 4),
                 image_key="image",
                 image_threshold=0,
             ),
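
For the getattr fallback to be reachable from the command line, the parser needs a matching option. A minimal sketch follows, assuming a hypothetical --num_samples flag that main.py does not currently define.

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical flag: number of crops RandCropByPosNegLabeld draws per volume.
parser.add_argument("--num_samples", type=int, default=4)

args = parser.parse_args(["--num_samples", "2"])

# An args object created elsewhere without the attribute still gets the old value.
legacy_args = argparse.Namespace()
fallback = getattr(legacy_args, "num_samples", 4)

print(args.num_samples, fallback)
```

Each volume contributes num_samples crops, so the number of crops per training step is batch_size times num_samples, which matters when tuning GPU memory use.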

9-29: Consider using MONAI’s list_data_collate to simplify collate_fn.

list_data_collate flattens RandCropByPosNegLabeld outputs and stacks tensors. If you keep custom collate, returning a tuple (images, labels) is more idiomatic than a list.

-        return [images.float(), labels.long()]
+        return images.float(), labels.long()

Alternative:

from monai.data import list_data_collate
# ...
trainLoader = DataLoader(trainDataset, ..., collate_fn=list_data_collate)
valLoader = DataLoader(valDataset, ..., collate_fn=list_data_collate)
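
The flattening step mentioned above can be shown on plain Python data. This sketch illustrates the idea only and is not MONAI's implementation; flatten_samples is a hypothetical name and the dicts hold strings instead of tensors.

```python
def flatten_samples(batch):
    """Flatten a batch whose items may be lists of crops into one flat list."""
    flat = []
    for item in batch:
        if isinstance(item, list):
            flat.extend(item)  # one volume produced several crops
        else:
            flat.append(item)  # one volume produced a single sample
    return flat


# Two volumes, each cropped into two samples by a random-crop transform.
batch = [
    [{"image": "v0_c0", "label": "l0_c0"}, {"image": "v0_c1", "label": "l0_c1"}],
    [{"image": "v1_c0", "label": "l1_c0"}, {"image": "v1_c1", "label": "l1_c1"}],
]
flat = flatten_samples(batch)
print(len(flat))
```

After flattening, a collate function stacks the samples key by key, so with list_data_collate the batch reaches the trainer as a dict and takes the dict-key path in train_epoch/val_epoch rather than the (data, target) sequence path.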
UNETR/BTCV/README.md (9)

8-8: Fix heading level (MD001).

Use H2 after an H1.

-### Installing Dependencies
+## Installing Dependencies

28-29: Align text with the new default pos_embed=‘learnable’.

Below you still say “position embedding is performed using a perceptron layer.” Update for consistency.

-The position embedding is performed using a perceptron layer. The ViT encoder follows standard hyper-parameters as introduced in [2].
+The position embedding uses learnable parameters by default. The ViT encoder follows standard hyper-parameters as introduced in [2].

70-75: Replace bare URLs with Markdown links and standardize “pretrained”.

Improves rendering and consistency.

-https://developer.download.nvidia.com/assets/Clara/monai/research/UNETR_model_best_acc.pth
+[UNETR_model_best_acc.pth](https://developer.download.nvidia.com/assets/Clara/monai/research/UNETR_model_best_acc.pth)
@@
-`./pretrained_models`
+`./pretrained_models`

95-101: Same: link formatting for TorchScript model.

-https://developer.download.nvidia.com/assets/Clara/monai/research/UNETR_model_best_acc.pt
+[UNETR_model_best_acc.pt](https://developer.download.nvidia.com/assets/Clara/monai/research/UNETR_model_best_acc.pt)
@@
-`./pretrained_models`
+`./pretrained_models`

141-147: Nice addition of predict mode. Add path notes for custom data.

Mention that images for prediction should be placed under PREDICT_DATA_ROOT (config.py) to match the new loader.

 python test.py
 --mode='predict'
 --infer_overlap=0.5
 --pretrained_dir='./pretrained_models/'
 --saved_checkpoint=ckpt
+# Images are read from PREDICT_DATA_ROOT defined in UNETR/BTCV/config.py.

171-175: Typo: “Left Kideny” → “Left Kidney”; minor formatting.

Also add spaces after numbers for readability.

-- Target: 13 abdominal organs including 1. Spleen 2. Right Kidney 3. Left Kideny 4.Gallbladder 5.Esophagus 6. Liver 7. Stomach 8.Aorta 9. IVC 10. Portal and Splenic Veins 11. Pancreas 12.Right adrenal gland 13.Left adrenal gland.
+- Target: 13 abdominal organs including 1. Spleen 2. Right Kidney 3. Left Kidney 4. Gallbladder 5. Esophagus 6. Liver 7. Stomach 8. Aorta 9. IVC 10. Portal and Splenic Veins 11. Pancreas 12. Right adrenal gland 13. Left adrenal gland.

155-162: Convert tutorial bare URLs to links.

-https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d.ipynb
+[UNETR BTCV tutorial (native)](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d.ipynb)
@@
-https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d_lightning.ipynb
+[UNETR BTCV tutorial (Lightning)](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d_lightning.ipynb)

176-180: Convert dataset JSON URL to link.

-https://developer.download.nvidia.com/assets/Clara/monai/tutorials/swin_unetr_btcv_dataset_0.json
+[swin_unetr_btcv_dataset_0.json](https://developer.download.nvidia.com/assets/Clara/monai/tutorials/swin_unetr_btcv_dataset_0.json)

55-63: Add a short “Custom NIfTI data” subsection referencing config.py.

Docs currently focus on BTCV + --data_dir. Add where to place NIFTI_DATA_ROOT/NIFTI_LABEL_ROOT and how to train without --btcv.

 Note that you need to provide the location of your dataset directory by using `--data_dir`.
 
 To initiate distributed multi-gpu training, `--distributed` needs to be added to the training command.
 
 To disable AMP, `--noamp` needs to be added to the training command.
 
 If UNETR is used in distributed multi-gpu training, we recommend increasing the learning rate (i.e. `--optim_lr`)
 according to the number of GPUs. For instance, `--optim_lr=4e-4` is recommended for training with 4 GPUs.
+
+#### Training on your own NIfTI data
+
+Set the following in `UNETR/BTCV/config.py`:
+
+- `NIFTI_DATA_ROOT`: folder containing images (`*.nii`/`*.nii.gz`)
+- `NIFTI_LABEL_ROOT`: folder containing corresponding labels (matching filenames)
+
+Then run:
+
+```bash
+python main.py
+```
+
+To train with the original BTCV json/dataloader instead, add `--btcv`.
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between f0b3dc4 and 4b366df.

📒 Files selected for processing (6)
  • UNETR/BTCV/README.md (7 hunks)
  • UNETR/BTCV/dataset/customDataset.py (1 hunks)
  • UNETR/BTCV/networks/unetr.py (3 hunks)
  • UNETR/BTCV/requirements.txt (1 hunks)
  • UNETR/BTCV/test.py (3 hunks)
  • UNETR/BTCV/trainer.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • UNETR/BTCV/test.py
  • UNETR/BTCV/trainer.py
  • UNETR/BTCV/networks/unetr.py
🧰 Additional context used
🪛 markdownlint-cli2 (0.17.2)
UNETR/BTCV/README.md

7-7: Heading levels should only increment by one level at a time
Expected: h2; Actual: h3

(MD001, heading-increment)


70-70: Bare URL used

(MD034, no-bare-urls)


157-157: Bare URL used

(MD034, no-bare-urls)


161-161: Bare URL used

(MD034, no-bare-urls)

🪛 LanguageTool
UNETR/BTCV/README.md

[grammar] ~35-~35: There might be a mistake here.
Context: ...egmentation outputs. The network expects resampled input images with size `(96, 9...

(QB_NEW_EN)


[grammar] ~36-~36: There might be a mistake here.
Context: ...rlapping patches of size (16, 16, 16). The position embedding is performed usin...

(QB_NEW_EN)


[grammar] ~37-~37: There might be a mistake here.
Context: ...d hyper-parameters as introduced in [2]. The decoder uses convolutional and resid...

(QB_NEW_EN)


[grammar] ~40-~40: There might be a mistake here.
Context: ...ommand can be used to initiate training using PyTorch native AMP package: ```bash py...

(QB_NEW_EN)


[style] ~59-~59: You have already used this phrasing in nearby sentences. Consider replacing it to add variety to your writing.
Context: ...ing command. To disable AMP, --noamp needs to be added to the training command. If U...

(REP_NEED_TO_VB)


[grammar] ~61-~61: There might be a mistake here.
Context: ...ng the learning rate (i.e. --optim_lr) according to the number of GPUs. For ins...

(QB_NEW_EN)


[grammar] ~66-~66: There might be a mistake here.
Context: ...kpoints and TorchScript models of UNETR using BTCV dataset. For using the pre-traine...

(QB_NEW_EN)


[uncategorized] ~72-~72: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...point in the following directory or use --pretrained_dir to provide the address of where th...

(EN_WORD_COHERENCY)


[uncategorized] ~97-~97: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...model in the following directory or use --pretrained_dir to provide the address of where th...

(EN_WORD_COHERENCY)


[grammar] ~125-~125: There might be a mistake here.
Context: ...ript model in the following directory or use --pretrained_dir to provide the ad...

(QB_NEW_EN)


[uncategorized] ~126-~126: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...model in the following directory or use --pretrained_dir to provide the address of where th...

(EN_WORD_COHERENCY)


[uncategorized] ~151-~151: Do not mix variants of the same word (‘pretrain’ and ‘pre-train’) within a single text.
Context: ...nce time. If you would like to use the pretrained TorchScript model, `--saved_checkpoint=...

(EN_WORD_COHERENCY)


[grammar] ~155-~155: There might be a mistake here.
Context: ...or the task of multi-organ segmentation using BTCV dataset can be found in the follow...

(QB_NEW_EN)

🪛 Ruff (0.12.2)
UNETR/BTCV/dataset/customDataset.py

92-92: Ambiguous variable name: l

(E741)

🔇 Additional comments (2)
UNETR/BTCV/requirements.txt (1)

6-7: Compatibility of dependencies verified. MONAI 1.5.0 officially supports PyTorch 2.3–2.6 (i.e. torch>=2.3,<2.7) and relies on the chosen PyTorch wheel for CUDA; PyTorch 2.3–2.6 is fully compatible with NumPy 2.x. Ensure your CI installs the appropriate torch+CUDA wheel per PyTorch’s recommendations.

UNETR/BTCV/dataset/customDataset.py (1)

102-116: Good call adding EnsureTyped(track_meta=True) for predict.

This preserves metadata for Invertd and safe .to(device) in test.py.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (3)
UNETR/BTCV/dataset/customDataset.py (3)

24-27: Good fix: preserve image float and label long dtypes in collate.

This addresses the earlier dtype corruption and avoids unnecessary copies.


31-39: Robust image/label pairing via extension filtering + intersection looks solid.

Nice guard against missing/extra files and mismatches.


120-121: Good addition: EnsureTyped(track_meta=True) for predict pipeline.

Enables Invertd and safe .to(device) without dropping metadata.
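
The pairing approach acknowledged above (extension filtering plus set intersection) can be sketched with the standard library alone. This is an illustrative reconstruction, not the PR's actual code: the helper names `list_nifti` and `pair_images_and_labels` are hypothetical, and the demo uses empty placeholder files in temporary directories rather than real NIfTI volumes.

```python
import os
import tempfile

EXTS = (".nii", ".nii.gz")  # same extension tuple as in the review diffs


def list_nifti(root):
    """Return the set of NIfTI file names located directly inside `root`."""
    return {
        f for f in os.listdir(root)
        if f.endswith(EXTS) and os.path.isfile(os.path.join(root, f))
    }


def pair_images_and_labels(image_root, label_root):
    """Pair files by identical name and report names found on one side only."""
    img_names = list_nifti(image_root)
    lbl_names = list_nifti(label_root)
    common = sorted(img_names & lbl_names)     # sorted for a deterministic order
    unmatched = sorted(img_names ^ lbl_names)  # symmetric difference
    pairs = [
        {"image": os.path.join(image_root, n), "label": os.path.join(label_root, n)}
        for n in common
    ]
    return pairs, unmatched


# Demo on throwaway directories containing empty placeholder files.
with tempfile.TemporaryDirectory() as img_dir, tempfile.TemporaryDirectory() as lbl_dir:
    for name in ("case1.nii.gz", "case2.nii", "case3.nii.gz", ".DS_Store"):
        open(os.path.join(img_dir, name), "w").close()
    for name in ("case1.nii.gz", "case2.nii", "notes.json"):
        open(os.path.join(lbl_dir, name), "w").close()
    pairs, unmatched = pair_images_and_labels(img_dir, lbl_dir)

print(len(pairs), unmatched)
```

In this demo, two pairs survive, `case3.nii.gz` is reported as unmatched, and the non-NIfTI files are ignored on both sides.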

🧹 Nitpick comments (4)
UNETR/BTCV/dataset/customDataset.py (4)

31-38: Add explicit directory existence checks for clearer errors.

os.listdir will raise a less-informative exception if roots are missing. Fail fast with a friendly message.

Apply:

 def getDatasetLoader(args):
-    exts = (".nii", ".nii.gz")
+    exts = (".nii", ".nii.gz")
+    if not os.path.isdir(NIFTI_DATA_ROOT):
+        raise FileNotFoundError(f"NIFTI_DATA_ROOT does not exist: {NIFTI_DATA_ROOT}")
+    if not os.path.isdir(NIFTI_LABEL_ROOT):
+        raise FileNotFoundError(f"NIFTI_LABEL_ROOT does not exist: {NIFTI_LABEL_ROOT}")
     img_names = {f for f in os.listdir(NIFTI_DATA_ROOT) if f.endswith(exts) and os.path.isfile(os.path.join(NIFTI_DATA_ROOT, f))}
     lbl_names = {f for f in os.listdir(NIFTI_LABEL_ROOT) if f.endswith(exts) and os.path.isfile(os.path.join(NIFTI_LABEL_ROOT, f))}

92-93: Enable pinned memory and persistent workers for faster loading.

Minor perf win on GPU; also reduces worker respawn overhead.

-    trainLoader = DataLoader(trainDataset,batch_size=args.batch_size,shuffle=True,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=True))
-    valLoader = DataLoader(valDataset,batch_size=args.batch_size,shuffle=False,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=False))
+    trainLoader = DataLoader(
+        trainDataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.workers,
+        pin_memory=torch.cuda.is_available(),
+        persistent_workers=bool(args.workers),
+        collate_fn=_get_collate_fn(isTrain=True),
+    )
+    valLoader = DataLoader(
+        valDataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.workers,
+        pin_memory=torch.cuda.is_available(),
+        persistent_workers=bool(args.workers),
+        collate_fn=_get_collate_fn(isTrain=False),
+    )

98-103: Rename ambiguous variable ‘l’ and guard split edge-cases.

Improves readability and satisfies lint (E741).

-def _splitList(l, trainRatio:float = 0.8):
-    totalNum = len(l)
-    splitIdx = int(totalNum * trainRatio)
-
-    return l[:splitIdx], l[splitIdx :]
+def _splitList(items, trainRatio: float = 0.8):
+    total = len(items)
+    splitIdx = int(total * trainRatio)
+    # ensure both splits non-empty when possible
+    if total > 1 and splitIdx == 0:
+        splitIdx = 1
+    if total > 1 and splitIdx == total:
+        splitIdx = total - 1
+    return items[:splitIdx], items[splitIdx:]
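
To see what the guards change, the suggested function can be run on a few small lists. The function body below is copied from the suggestion above; the file names are made-up placeholders. Note that a single-item list still produces an empty training split, because both guards require `total > 1`.

```python
def _splitList(items, trainRatio: float = 0.8):
    total = len(items)
    splitIdx = int(total * trainRatio)
    # ensure both splits non-empty when possible
    if total > 1 and splitIdx == 0:
        splitIdx = 1
    if total > 1 and splitIdx == total:
        splitIdx = total - 1
    return items[:splitIdx], items[splitIdx:]


names = ["a.nii.gz", "b.nii.gz", "c.nii.gz", "d.nii.gz", "e.nii.gz"]

default_split = _splitList(names)          # 4 train / 1 val
low_ratio = _splitList(names[:2], 0.1)     # int(0.2) == 0, guard bumps it to 1
full_ratio = _splitList(names[:3], 1.0)    # splitIdx == total, guard lowers it to 2
single_item = _splitList(names[:1])        # guards do not apply: train split is empty

for result in (default_split, low_ratio, full_ratio, single_item):
    print(len(result[0]), len(result[1]))
```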

74-88: Optional: Use EnsureTyped for val if you plan to invert transforms during validation.

Only needed if validation pipeline applies Invertd to recover original spacing/origin for metrics/saving.

     val_transform = transforms.Compose(
             [
                 transforms.LoadImaged(keys=["image", "label"]),
                 transforms.EnsureChannelFirstd(keys=["image", "label"]),
                 transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
                 transforms.Spacingd(
                     keys=["image", "label"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear", "nearest")
                 ),
                 transforms.ScaleIntensityRanged(
                     keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
                 ),
                 transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
-                transforms.ToTensord(keys=["image", "label"]),
+                transforms.EnsureTyped(keys=["image", "label"], track_meta=True),
             ]
         )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4b366df and 30bc5b1.

📒 Files selected for processing (1)
  • UNETR/BTCV/dataset/customDataset.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
UNETR/BTCV/dataset/customDataset.py

37-37: Avoid specifying long messages outside the exception class

(TRY003)


98-98: Ambiguous variable name: l

(E741)

🔇 Additional comments (1)
UNETR/BTCV/dataset/customDataset.py (1)

54-63: Confirm effective batch size (batch_size × num_samples) fits memory.

With RandCropByPosNegLabeld(num_samples=4) and custom collate, each DataLoader batch becomes batch_size*4 patches. Ensure trainer/optimizer expect this and your GPU can handle it, or reduce num_samples or batch_size.

Also applies to: 92-93
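
As a rough worked example of the multiplication above, the sketch below computes the number of patches per optimizer step and the size of the input tensor alone. `num_samples=4` comes from the comment; `batch_size=2`, the `(96, 96, 96)` ROI, the single input channel, and float32 storage are assumptions chosen for illustration. Activations and gradients typically dominate GPU memory and are not estimated here.

```python
def patches_per_step(batch_size: int, num_samples: int) -> int:
    """Each loaded volume yields `num_samples` crops, so one batch holds this many patches."""
    return batch_size * num_samples


def input_tensor_bytes(patches: int, roi, channels: int = 1, bytes_per_value: int = 4) -> int:
    """Size of the input batch only (float32 by default); excludes activations and gradients."""
    voxels = roi[0] * roi[1] * roi[2]
    return patches * channels * voxels * bytes_per_value


batch_size = 2        # assumed value, not taken from the PR
num_samples = 4       # RandCropByPosNegLabeld(num_samples=4), per the comment above
roi = (96, 96, 96)    # assumed ROI size

patches = patches_per_step(batch_size, num_samples)
size_mib = input_tensor_bytes(patches, roi) / 2**20

print(f"{patches} patches per step, input tensor ~{size_mib:.1f} MiB")
```

Halving either `batch_size` or `num_samples` halves both numbers, which is the trade-off the comment asks to verify.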

Comment on lines +105 to +107
dataName = [d for d in os.listdir(PREDICT_DATA_ROOT)]
dataDicts = [{"image": f"{os.path.join(PREDICT_DATA_ROOT, d)}" } for d in dataName]

🛠️ Refactor suggestion

Filter predict files to NIfTI and validate presence.

An unfiltered os.listdir may return non-NIfTI files (.DS_Store, JSON, etc.), which will break LoadImaged.

-def getPredictLoader(args):
-    dataName = [d for d in os.listdir(PREDICT_DATA_ROOT)]
-    dataDicts = [{"image": f"{os.path.join(PREDICT_DATA_ROOT, d)}" } for d in dataName]
+def getPredictLoader(args):
+    exts = (".nii", ".nii.gz")
+    if not os.path.isdir(PREDICT_DATA_ROOT):
+        raise FileNotFoundError(f"PREDICT_DATA_ROOT does not exist: {PREDICT_DATA_ROOT}")
+    files = sorted(
+        f for f in os.listdir(PREDICT_DATA_ROOT)
+        if f.endswith(exts) and os.path.isfile(os.path.join(PREDICT_DATA_ROOT, f))
+    )
+    if not files:
+        raise FileNotFoundError(f"No NIfTI files (.nii, .nii.gz) found in {PREDICT_DATA_ROOT}")
+    dataDicts = [{"image": os.path.join(PREDICT_DATA_ROOT, f)} for f in files]
🤖 Prompt for AI Agents
In UNETR/BTCV/dataset/customDataset.py around lines 105 to 107, the code builds
dataName from os.listdir which can include non-NIfTI files and will break
LoadImaged; change it to only include files with NIfTI extensions (e.g., .nii,
.nii.gz) using a filter (or glob) and ensure each entry is a regular file, build
dataDicts from those paths, and add a validation step that logs/raises an error
if no valid NIfTI files are found.
