Commit 8c6c306

C-Achard and MMathisLab authored
Colab notebooks update (#103)
* Fix use of deprecated arg in colab training
* Refactor model save name path + comment wandb cell
* Update Colab_WNet3D_training.ipynb
* Improve logging in Colab
* Subclass WnetTraininWorker to avoid duplication
* Remove strict channel first
* Add missing channel_dim, remove strict_check=False
* Update worker_training.py
* Update worker_training.py
* Disable strict checks for channelfirstd
* Update worker_training.py
* Temp disable channel first
* Fix init of Colab worker
* Move issues with transforms to colab script + disable pad/channelfirst
* Enable ChannelFirst again
* Remove strict_check = False in original worker
  Seems to be a Colab-specific issue
* Remove redundant code + Colab notebook tweaks
* Revert wandb check
* Update docs + Colab inference
* Update training_wnet.rst
* Update Colab_WNet3D_training.ipynb
* update / WIP
* Update Colab_inference_demo.ipynb
* Update Colab_inference_demo.ipynb
* Update Colab_inference_demo.ipynb
* Update Colab_inference_demo.ipynb
* Update Colab_inference_demo.ipynb
* Update Colab_inference_demo.ipynb
* Update Colab_inference_demo.ipynb
* Update Colab_inference_demo.ipynb
* nearly final!
* exec
* final

---------

Co-authored-by: Mackenzie Mathis <[email protected]>
Co-authored-by: Mackenzie Mathis <[email protected]>
1 parent bb806f0 commit 8c6c306
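
One of the commits listed above subclasses the WNet training worker so that the Colab notebook can reuse the plugin's training loop rather than duplicating it. Below is a minimal sketch of that pattern; the class and method names are hypothetical stand-ins, not the plugin's actual API (the real worker lives in napari_cellseg3d/code_models/worker_training.py).

# Hypothetical sketch of "subclass to avoid duplication": the Colab variant
# overrides only the environment-specific pieces and inherits the rest.
class WNetTrainingWorker:
    """Stand-in for the plugin-side training worker."""

    def log(self, message: str) -> None:
        # In napari, this would write to the plugin's log widget.
        print(message)

    def train(self) -> None:
        self.log("starting training")
        # ... shared training loop would live here ...


class ColabWNetTrainingWorker(WNetTrainingWorker):
    """Colab variant: overrides logging only, inherits the training loop."""

    def log(self, message: str) -> None:
        # Colab has no napari GUI, so route messages to the notebook's stdout.
        print(f"[colab] {message}")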

File tree

5 files changed (+689, -1262 lines)


docs/source/guides/training_wnet.rst

Lines changed: 9 additions & 12 deletions
@@ -18,24 +18,21 @@ The WNet3D **does not require a large amount of data to train**, but **choosing
 
 You may find below some guidelines, based on our own data and testing.
 
-The WNet3D is designed to segment objects based on their brightness, and is particularly well-suited for images with a clear contrast between objects and background.
-
-The WNet3D is not suitable for images with artifacts, therefore care should be taken that the images are clean and that the objects are at least somewhat distinguishable from the background.
+The WNet3D is a self-supervised learning approach for 3D cell segmentation, and relies on the assumption that structural and morphological features of cells can be inferred directly from unlabeled data. This involves leveraging inherent properties such as spatial coherence and local contrast in imaging volumes to distinguish cellular structures. This approach assumes that meaningful representations of cellular boundaries and nuclei can emerge solely from raw 3D volumes. Thus, we strongly recommend that you use WNet3D on stacks that have clear foreground/background segregation and limited noise. Even if your final samples have noise, it is best to train on data that is as clean as you can.
 
 
 .. important::
     For optimal performance, the following should be avoided for training:
 
-    - Images with very large, bright regions
-    - Almost-empty and empty images
-    - Images with large empty regions or "holes"
+    - Images with over-exposed pixels/artifacts you do not want to be learned!
+    - Almost-empty and/or fully empty images, especially if noise is present (it will learn to segment very small objects!).
 
-    However, the model may be accomodate:
+    However, the model may accomodate:
 
-    - Uneven brightness distribution
-    - Varied object shapes and radius
-    - Noisy images
-    - Uneven illumination across the image
+    - Uneven brightness distribution in your image!
+    - Varied object shapes and radius!
+    - Noisy images (as long as resolution is sufficient and boundaries are clear)!
+    - Uneven illumination across the image!
 
 For optimal results, during inference, images should be similar to those the model was trained on; however this is not a strict requirement.
 
@@ -88,7 +85,7 @@ Common issues troubleshooting
 If you do not find a satisfactory answer here, please do not hesitate to `open an issue`_ on GitHub.
 
 
-- **The NCuts loss "explodes" after a few epochs** : Lower the learning rate, for example start with a factor of two, then ten.
+- **The NCuts loss "explodes" upward after a few epochs** : Lower the learning rate, for example start with a factor of two, then ten.
 
 - **Reconstruction (decoder) performance is poor** : First, try increasing the weight of the reconstruction loss. If this is ineffective, switch to BCE loss and set the scaling factor of the reconstruction loss to 0.5, OR adjust the weight of the MSE loss.
 
napari_cellseg3d/code_models/worker_training.py

Lines changed: 14 additions & 6 deletions
@@ -1,4 +1,5 @@
 """Contains the workers used to train the models."""
+
 import platform
 import time
 from abc import abstractmethod
@@ -200,7 +201,10 @@ def get_patch_dataset(self, train_transforms):
         patch_func = Compose(
             [
                 LoadImaged(keys=["image"], image_only=True),
-                EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
+                EnsureChannelFirstd(
+                    keys=["image"],
+                    channel_dim="no_channel",
+                ),
                 RandSpatialCropSamplesd(
                     keys=["image"],
                     roi_size=(
@@ -235,7 +239,8 @@ def get_dataset_eval(self, eval_dataset_dict):
             [
                 LoadImaged(keys=["image", "label"]),
                 EnsureChannelFirstd(
-                    keys=["image", "label"], channel_dim="no_channel"
+                    keys=["image", "label"],
+                    channel_dim="no_channel",
                 ),
                 # RandSpatialCropSamplesd(
                 #     keys=["image", "label"],
@@ -280,7 +285,10 @@ def get_dataset(self, train_transforms):
         load_single_images = Compose(
             [
                 LoadImaged(keys=["image"]),
-                EnsureChannelFirstd(keys=["image"]),
+                EnsureChannelFirstd(
+                    keys=["image"],
+                    channel_dim="no_channel",
+                ),
                 Orientationd(keys=["image"], axcodes="PLI"),
                 SpatialPadd(
                     keys=["image"],
@@ -1345,9 +1353,9 @@ def get_patch_loader_func(num_samples):
             )
             sample_loader_eval = get_patch_loader_func(num_val_samples)
         else:
-            num_train_samples = (
-                num_val_samples
-            ) = self.config.num_samples
+            num_train_samples = num_val_samples = (
+                self.config.num_samples
+            )
 
             sample_loader_train = get_patch_loader_func(
                 num_train_samples
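
The recurring change in this file is passing channel_dim="no_channel" explicitly to MONAI's EnsureChannelFirstd (and splitting the call over several lines), rather than relying on image metadata or on strict_check. A standalone sketch of what that argument does for a raw 3D volume with no channel axis; this is an illustration, not the worker's code.

import numpy as np
from monai.transforms import Compose, EnsureChannelFirstd

# A raw 3D stack with shape (D, H, W): no channel axis, no metadata.
volume = np.random.rand(32, 64, 64).astype(np.float32)

transform = Compose(
    [
        # Without metadata MONAI cannot infer where the channel axis is;
        # channel_dim="no_channel" tells it to prepend one instead.
        EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
    ]
)

out = transform({"image": volume})
print(out["image"].shape)  # expected: (1, 32, 64, 64)

The multi-line formatting in the hunks above is stylistic; the behavioural change is supplying channel_dim where it was previously missing (notably in get_dataset).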
