Skip to content

Commit

Permalink
Merge pull request #27 from TheJacksonLaboratory/dev
Browse files Browse the repository at this point in the history
Improvement to sampling positions management
  • Loading branch information
fercer authored Nov 29, 2024
2 parents d057b46 + 697b918 commit f6c505c
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 61 deletions.
88 changes: 50 additions & 38 deletions src/napari_activelearning/_acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import random
from pathlib import Path
import numpy as np
import math

import zarrdataset as zds
import dask.array as da
Expand Down Expand Up @@ -102,12 +103,19 @@ def add_multiscale_output_layer(
reference_scale=reference_scale
)

is_multiscale = False
if len(output_fun_ms) > 1:
is_multiscale = True
else:
output_fun_ms = output_fun_ms[0]

func_args = dict(
data=output_fun_ms,
name=group_name,
multiscale=True,
multiscale=is_multiscale,
opacity=0.8,
scale=scale,
translate=tuple(scl / 2.0 if scl > 1 else 0 for scl in scale),
blending="translucent_no_depth",
)

Expand Down Expand Up @@ -233,8 +241,7 @@ def _fine_tune(self, train_data, train_labels, test_data, test_labels):
raise NotImplementedError("This method requies to be overriden by a "
"derived class.")

def fine_tune(self, dataset_metadata_list: Iterable[
Tuple[dict, Iterable[Iterable[int]]]],
def fine_tune(self, dataset_metadata_list: Iterable[dict],
train_data_proportion: float = 0.8,
patch_sizes: Union[dict, int] = 256,
model_axes="YXC"):
Expand All @@ -245,19 +252,12 @@ def fine_tune(self, dataset_metadata_list: Iterable[

transform = self._get_transform()

for dataset_metadata, top_lefts in dataset_metadata_list:
if top_lefts is not None:
patch_sampler = StaticPatchSampler(
patch_size=patch_sizes,
top_lefts=top_lefts,
spatial_axes=dataset_metadata["labels"]["axes"]
)
else:
patch_sampler = zds.PatchSampler(
patch_size=patch_sizes,
spatial_axes=dataset_metadata["labels"]["axes"],
min_area=0.05
)
for dataset_metadata in dataset_metadata_list:
patch_sampler = zds.PatchSampler(
patch_size=patch_sizes,
spatial_axes=dataset_metadata["labels"]["axes"],
min_area=0.05
)

dataset = zds.ZarrDataset(
list(dataset_metadata.values()),
Expand Down Expand Up @@ -357,7 +357,6 @@ def _prepare_datasets_metadata(
displayed_shape: Iterable[int],
layer_types: Iterable[Tuple[LayersGroup, str]]):
dataset_metadata = {}
sampling_positions = None

for layers_group, layer_type in layer_types:
if layers_group is None:
Expand All @@ -366,17 +365,25 @@ def _prepare_datasets_metadata(
dataset_metadata[layer_type] = layers_group.metadata
dataset_metadata[layer_type]["roi"] = None

(reference_source_axes,
reference_shape) = list(zip(*[
(ax, ax_s)
for ax, ax_s in zip(displayed_source_axes, displayed_shape)
if layer_type not in ["labels", "masks"] or ax != "C"]))

if layer_type in ["images", "labels", "masks"]:
dataset_metadata[layer_type]["roi"] = [tuple(
slice(0, ax_s - ax_s % self._patch_sizes.get(ax, 1))
slice(0, math.ceil(
lyr_s / ax_s
* (ax_s - ax_s % self._patch_sizes.get(ax, 1))
))
if (ax != "C"
and (ax in self.model_axes
or ax_s > self._patch_sizes.get(ax, 1)))
else slice(None)
for ax, ax_s in zip(displayed_source_axes,
displayed_shape)
if (layer_type == "images"
or (layer_type in ["labels", "masks"] and ax != "C"))
for ax, ax_s, lyr_s in zip(reference_source_axes,
reference_shape,
layers_group.shape)
)]

if isinstance(dataset_metadata[layer_type]["filenames"],
Expand Down Expand Up @@ -420,14 +427,11 @@ def _prepare_datasets_metadata(
labels
)

sampling_positions = list(spatial_pos)

return dataset_metadata, sampling_positions
return dataset_metadata

def compute_acquisition(self, dataset_metadata, acquisition_fun,
segmentation_out,
sampled_mask=None,
sampling_positions=None,
segmentation_only=False):
model_spatial_axes = [
ax
Expand All @@ -444,7 +448,6 @@ def compute_acquisition(self, dataset_metadata, acquisition_fun,
input_spatial_axes = "".join(input_spatial_axes)

dl = get_dataloader(dataset_metadata, patch_size=self._patch_sizes,
sampling_positions=sampling_positions,
spatial_axes=input_spatial_axes,
model_input_axes=self.model_axes,
shuffle=True)
Expand Down Expand Up @@ -511,7 +514,12 @@ def compute_acquisition(self, dataset_metadata, acquisition_fun,
segmentation_max = max(segmentation_max, seg_out.max())

if sampled_mask is not None:
sampled_mask[pos_u_lab] = True
scaled_pos_u_lab = tuple(
slice(pos.get(ax, 1).start // self._patch_sizes.get(ax, 1),
pos.get(ax, 1).stop // self._patch_sizes.get(ax, 1))
for ax in input_spatial_axes
)
sampled_mask[scaled_pos_u_lab] = True

img_sampling_positions.append(
LabelItem(acquisition_val, position=pos_u_lab)
Expand Down Expand Up @@ -618,8 +626,7 @@ def compute_acquisition_layers(
f"labels/{segmentation_group_name}/0"
]

(dataset_metadata,
sampling_positions) = self._prepare_datasets_metadata(
dataset_metadata = self._prepare_datasets_metadata(
image_group,
output_axes,
displayed_source_axes,
Expand All @@ -628,11 +635,19 @@ def compute_acquisition_layers(
(sampling_mask_layers_group, "masks")]
)

if sampling_positions is None:
if "sampled_positions" not in segmentation_root["labels"].keys():
(sampling_output_shape,
sampling_output_scale) = list(zip(*[
(math.ceil(ax_s // self._patch_sizes.get(ax, 1)),
ax_scl * self._patch_sizes.get(ax, 1))
for ax, ax_s, ax_scl in zip(output_axes,
output_shape,
output_scale)]))

sampled_root = save_zarr(
output_filename,
data=None,
shape=output_shape,
shape=sampling_output_shape,
chunk_size=True,
name="sampled_positions",
dtype=bool,
Expand All @@ -650,7 +665,6 @@ def compute_acquisition_layers(
acquisition_fun=acquisition_fun_grp,
segmentation_out=segmentation_grp,
sampled_mask=sampled_grp,
sampling_positions=sampling_positions,
segmentation_only=segmentation_only
)

Expand Down Expand Up @@ -682,7 +696,7 @@ def compute_acquisition_layers(
add_multiscale_output_layer(
sampled_root,
axes=output_axes,
scale=output_scale,
scale=sampling_output_scale,
data_group="labels/sampled_positions/0",
group_name=group_name + " sampled positions",
layers_group_name="sampled positions",
Expand Down Expand Up @@ -771,17 +785,15 @@ def fine_tune(self):
output_axes.remove("C")
output_axes = "".join(output_axes)

(dataset_metadata,
sampling_positions) = self._prepare_datasets_metadata(
dataset_metadata = self._prepare_datasets_metadata(
image_group,
output_axes,
displayed_source_axes,
displayed_shape,
layer_types,
)

dataset_metadata_list.append((dataset_metadata,
sampling_positions))
dataset_metadata_list.append(dataset_metadata)

self.tunable_segmentation_method.fine_tune(
dataset_metadata_list,
Expand Down
15 changes: 6 additions & 9 deletions src/napari_activelearning/_tests/test_acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def test_compute_acquisition(image_groups_manager, labels_manager,
}
acquisition_fun = np.zeros((1, 1, 10, 10))
segmentation_out = np.zeros((1, 1, 10, 10))
sampling_positions = None
segmentation_only = False

acquisition_function.input_axes = "TZYX"
Expand All @@ -78,9 +77,10 @@ def test_compute_acquisition(image_groups_manager, labels_manager,
]

result = acquisition_function.compute_acquisition(
dataset_metadata, acquisition_fun, segmentation_out,
sampling_positions,
segmentation_only
dataset_metadata,
acquisition_fun=acquisition_fun,
segmentation_out=segmentation_out,
segmentation_only=segmentation_only
)

assert len(result) == 1
Expand Down Expand Up @@ -152,16 +152,13 @@ def test_prepare_datasets_metadata(image_groups_manager, labels_manager,
layer_types = [(layers_group, "images")]

# Call the method
(dataset_metadata,
sampling_positions) = acquisition_function._prepare_datasets_metadata(
dataset_metadata = acquisition_function._prepare_datasets_metadata(
image_group,
output_axes,
displayed_source_axes,
displayed_shape,
layer_types
)
layer_types)

assert sampling_positions is None
expected_dataset_metadata = {
"images": {
"filenames": layers_group.source_data,
Expand Down
8 changes: 3 additions & 5 deletions src/napari_activelearning/_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import operator

import numpy as np
import zarrdataset as zds
import zarr

from napari.layers._multiscale_data import MultiScaleData
Expand Down Expand Up @@ -167,7 +168,6 @@ def test_get_basename():

def test_get_dataloader(dataset_metadata):
patch_size = {"Y": 64, "X": 64}
sampling_positions = [[0, 0], [0, 64], [64, 0], [64, 64]]
shuffle = True
num_workers = 4
batch_size = 8
Expand All @@ -177,7 +177,6 @@ def test_get_dataloader(dataset_metadata):
dataloader = get_dataloader(
dataset_metadata,
patch_size=patch_size,
sampling_positions=sampling_positions,
shuffle=shuffle,
num_workers=num_workers,
batch_size=batch_size,
Expand All @@ -186,10 +185,9 @@ def test_get_dataloader(dataset_metadata):
)

if USING_PYTORCH:
assert isinstance(dataloader.dataset._patch_sampler,
StaticPatchSampler)
assert isinstance(dataloader.dataset._patch_sampler, zds.PatchSampler)
else:
assert isinstance(dataloader._patch_sampler, StaticPatchSampler)
assert isinstance(dataloader._patch_sampler, zds.PatchSampler)


def test_compute_chunks(image_collection):
Expand Down
12 changes: 3 additions & 9 deletions src/napari_activelearning/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,6 @@ def compute_patches(self, image_collection: zds.ImageCollection,

def get_dataloader(
dataset_metadata, patch_size: dict,
sampling_positions: Optional[Iterable[Iterable[int]]] = None,
shuffle: bool = True,
num_workers: int = 0,
batch_size: int = 1,
Expand All @@ -237,14 +236,9 @@ def get_dataloader(
modality="superpixels"
)

if sampling_positions:
patch_sampler = StaticPatchSampler(patch_size=patch_size,
top_lefts=sampling_positions,
spatial_axes=spatial_axes)
else:
patch_sampler = zds.PatchSampler(patch_size=patch_size,
spatial_axes=spatial_axes,
min_area=0.05)
patch_sampler = zds.PatchSampler(patch_size=patch_size,
spatial_axes=spatial_axes,
min_area=0.05)

train_dataset = zds.ZarrDataset(
list(dataset_metadata.values()),
Expand Down

0 comments on commit f6c505c

Please sign in to comment.