diff --git a/src/napari_activelearning/_acquisition.py b/src/napari_activelearning/_acquisition.py index 96312f2..47a745c 100644 --- a/src/napari_activelearning/_acquisition.py +++ b/src/napari_activelearning/_acquisition.py @@ -3,6 +3,7 @@ import random from pathlib import Path import numpy as np +import math import zarrdataset as zds import dask.array as da @@ -102,12 +103,19 @@ def add_multiscale_output_layer( reference_scale=reference_scale ) + is_multiscale = False + if len(output_fun_ms) > 1: + is_multiscale = True + else: + output_fun_ms = output_fun_ms[0] + func_args = dict( data=output_fun_ms, name=group_name, - multiscale=True, + multiscale=is_multiscale, opacity=0.8, scale=scale, + translate=tuple(scl / 2.0 if scl > 1 else 0 for scl in scale), blending="translucent_no_depth", ) @@ -233,8 +241,7 @@ def _fine_tune(self, train_data, train_labels, test_data, test_labels): raise NotImplementedError("This method requies to be overriden by a " "derived class.") - def fine_tune(self, dataset_metadata_list: Iterable[ - Tuple[dict, Iterable[Iterable[int]]]], + def fine_tune(self, dataset_metadata_list: Iterable[dict], train_data_proportion: float = 0.8, patch_sizes: Union[dict, int] = 256, model_axes="YXC"): @@ -245,19 +252,12 @@ def fine_tune(self, dataset_metadata_list: Iterable[ transform = self._get_transform() - for dataset_metadata, top_lefts in dataset_metadata_list: - if top_lefts is not None: - patch_sampler = StaticPatchSampler( - patch_size=patch_sizes, - top_lefts=top_lefts, - spatial_axes=dataset_metadata["labels"]["axes"] - ) - else: - patch_sampler = zds.PatchSampler( - patch_size=patch_sizes, - spatial_axes=dataset_metadata["labels"]["axes"], - min_area=0.05 - ) + for dataset_metadata in dataset_metadata_list: + patch_sampler = zds.PatchSampler( + patch_size=patch_sizes, + spatial_axes=dataset_metadata["labels"]["axes"], + min_area=0.05 + ) dataset = zds.ZarrDataset( list(dataset_metadata.values()), @@ -357,7 +357,6 @@ def _prepare_datasets_metadata( displayed_shape: Iterable[int], layer_types: Iterable[Tuple[LayersGroup, str]]): dataset_metadata = {} - sampling_positions = None for layers_group, layer_type in layer_types: if layers_group is None: @@ -366,17 +365,25 @@ def _prepare_datasets_metadata( dataset_metadata[layer_type] = layers_group.metadata dataset_metadata[layer_type]["roi"] = None + (reference_source_axes, + reference_shape) = list(zip(*[ + (ax, ax_s) + for ax, ax_s in zip(displayed_source_axes, displayed_shape) + if layer_type not in ["labels", "masks"] or ax != "C"])) + if layer_type in ["images", "labels", "masks"]: dataset_metadata[layer_type]["roi"] = [tuple( - slice(0, ax_s - ax_s % self._patch_sizes.get(ax, 1)) + slice(0, math.ceil( + lyr_s / ax_s + * (ax_s - ax_s % self._patch_sizes.get(ax, 1)) + )) if (ax != "C" and (ax in self.model_axes or ax_s > self._patch_sizes.get(ax, 1))) else slice(None) - for ax, ax_s in zip(displayed_source_axes, - displayed_shape) - if (layer_type == "images" - or (layer_type in ["labels", "masks"] and ax != "C")) + for ax, ax_s, lyr_s in zip(reference_source_axes, + reference_shape, + layers_group.shape) )] if isinstance(dataset_metadata[layer_type]["filenames"], @@ -420,14 +427,11 @@ def _prepare_datasets_metadata( labels ) - sampling_positions = list(spatial_pos) - - return dataset_metadata, sampling_positions + return dataset_metadata def compute_acquisition(self, dataset_metadata, acquisition_fun, segmentation_out, sampled_mask=None, - sampling_positions=None, segmentation_only=False): model_spatial_axes = [ ax @@ -444,7 +448,6 @@ def compute_acquisition(self, dataset_metadata, acquisition_fun, input_spatial_axes = "".join(input_spatial_axes) dl = get_dataloader(dataset_metadata, patch_size=self._patch_sizes, - sampling_positions=sampling_positions, spatial_axes=input_spatial_axes, model_input_axes=self.model_axes, shuffle=True) @@ -511,7 +514,12 @@ def compute_acquisition(self, dataset_metadata, acquisition_fun, segmentation_max = max(segmentation_max, seg_out.max()) if sampled_mask is not None: - sampled_mask[pos_u_lab] = True + scaled_pos_u_lab = tuple( + slice(pos.get(ax, 1).start // self._patch_sizes.get(ax, 1), + pos.get(ax, 1).stop // self._patch_sizes.get(ax, 1)) + for ax in input_spatial_axes + ) + sampled_mask[scaled_pos_u_lab] = True img_sampling_positions.append( LabelItem(acquisition_val, position=pos_u_lab) @@ -618,8 +626,7 @@ def compute_acquisition_layers( f"labels/{segmentation_group_name}/0" ] - (dataset_metadata, - sampling_positions) = self._prepare_datasets_metadata( + dataset_metadata = self._prepare_datasets_metadata( image_group, output_axes, displayed_source_axes, @@ -628,11 +635,19 @@ def compute_acquisition_layers( (sampling_mask_layers_group, "masks")] ) - if sampling_positions is None: + if "sampled_positions" not in segmentation_root["labels"].keys(): + (sampling_output_shape, + sampling_output_scale) = list(zip(*[ + (math.ceil(ax_s // self._patch_sizes.get(ax, 1)), + ax_scl * self._patch_sizes.get(ax, 1)) + for ax, ax_s, ax_scl in zip(output_axes, + output_shape, + output_scale)])) + sampled_root = save_zarr( output_filename, data=None, - shape=output_shape, + shape=sampling_output_shape, chunk_size=True, name="sampled_positions", dtype=bool, @@ -650,7 +665,6 @@ def compute_acquisition_layers( acquisition_fun=acquisition_fun_grp, segmentation_out=segmentation_grp, sampled_mask=sampled_grp, - sampling_positions=sampling_positions, segmentation_only=segmentation_only ) @@ -682,7 +696,7 @@ def compute_acquisition_layers( add_multiscale_output_layer( sampled_root, axes=output_axes, - scale=output_scale, + scale=sampling_output_scale, data_group="labels/sampled_positions/0", group_name=group_name + " sampled positions", layers_group_name="sampled positions", @@ -771,8 +785,7 @@ def fine_tune(self): output_axes.remove("C") output_axes = "".join(output_axes) - (dataset_metadata, - sampling_positions) = self._prepare_datasets_metadata( + dataset_metadata = self._prepare_datasets_metadata( image_group, output_axes, displayed_source_axes, @@ -780,8 +793,7 @@ def fine_tune(self): layer_types, ) - dataset_metadata_list.append((dataset_metadata, - sampling_positions)) + dataset_metadata_list.append(dataset_metadata) self.tunable_segmentation_method.fine_tune( dataset_metadata_list, diff --git a/src/napari_activelearning/_tests/test_acquisition.py b/src/napari_activelearning/_tests/test_acquisition.py index d1e2034..aa21869 100644 --- a/src/napari_activelearning/_tests/test_acquisition.py +++ b/src/napari_activelearning/_tests/test_acquisition.py @@ -54,7 +54,6 @@ def test_compute_acquisition(image_groups_manager, labels_manager, } acquisition_fun = np.zeros((1, 1, 10, 10)) segmentation_out = np.zeros((1, 1, 10, 10)) - sampling_positions = None segmentation_only = False acquisition_function.input_axes = "TZYX" @@ -78,9 +77,10 @@ def test_compute_acquisition(image_groups_manager, labels_manager, ] result = acquisition_function.compute_acquisition( - dataset_metadata, acquisition_fun, segmentation_out, - sampling_positions, - segmentation_only + dataset_metadata, + acquisition_fun=acquisition_fun, + segmentation_out=segmentation_out, + segmentation_only=segmentation_only ) assert len(result) == 1 @@ -152,16 +152,13 @@ def test_prepare_datasets_metadata(image_groups_manager, labels_manager, layer_types = [(layers_group, "images")] # Call the method - (dataset_metadata, - sampling_positions) = acquisition_function._prepare_datasets_metadata( + dataset_metadata = acquisition_function._prepare_datasets_metadata( image_group, output_axes, displayed_source_axes, displayed_shape, - layer_types - ) + layer_types) - assert sampling_positions is None expected_dataset_metadata = { "images": { "filenames": layers_group.source_data, diff --git a/src/napari_activelearning/_tests/test_utils.py b/src/napari_activelearning/_tests/test_utils.py index 09c3b72..8228021 100644 --- a/src/napari_activelearning/_tests/test_utils.py +++ b/src/napari_activelearning/_tests/test_utils.py @@ -3,6 +3,7 @@ import operator import numpy as np +import zarrdataset as zds import zarr from napari.layers._multiscale_data import MultiScaleData @@ -167,7 +168,6 @@ def test_get_basename(): def test_get_dataloader(dataset_metadata): patch_size = {"Y": 64, "X": 64} - sampling_positions = [[0, 0], [0, 64], [64, 0], [64, 64]] shuffle = True num_workers = 4 batch_size = 8 @@ -177,7 +177,6 @@ def test_get_dataloader(dataset_metadata): dataloader = get_dataloader( dataset_metadata, patch_size=patch_size, - sampling_positions=sampling_positions, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size, @@ -186,10 +185,9 @@ def test_get_dataloader(dataset_metadata): ) if USING_PYTORCH: - assert isinstance(dataloader.dataset._patch_sampler, - StaticPatchSampler) + assert isinstance(dataloader.dataset._patch_sampler, zds.PatchSampler) else: - assert isinstance(dataloader._patch_sampler, StaticPatchSampler) + assert isinstance(dataloader._patch_sampler, zds.PatchSampler) def test_compute_chunks(image_collection): diff --git a/src/napari_activelearning/_utils.py b/src/napari_activelearning/_utils.py index 90f85e0..946fa8b 100644 --- a/src/napari_activelearning/_utils.py +++ b/src/napari_activelearning/_utils.py @@ -213,7 +213,6 @@ def compute_patches(self, image_collection: zds.ImageCollection, def get_dataloader( dataset_metadata, patch_size: dict, - sampling_positions: Optional[Iterable[Iterable[int]]] = None, shuffle: bool = True, num_workers: int = 0, batch_size: int = 1, @@ -237,14 +236,9 @@ def get_dataloader( modality="superpixels" ) - if sampling_positions: - patch_sampler = StaticPatchSampler(patch_size=patch_size, - top_lefts=sampling_positions, - spatial_axes=spatial_axes) - else: - patch_sampler = zds.PatchSampler(patch_size=patch_size, - spatial_axes=spatial_axes, - min_area=0.05) + patch_sampler = zds.PatchSampler(patch_size=patch_size, + spatial_axes=spatial_axes, + min_area=0.05) train_dataset = zds.ZarrDataset( list(dataset_metadata.values()),