Skip to content

Commit

Permalink
Merge pull request #7 from TheJacksonLaboratory/cellpose_dep
Browse files Browse the repository at this point in the history
Improved cellpose model usability and dependencies
  • Loading branch information
fercer authored Aug 6, 2024
2 parents 6543033 + 3ee9049 commit 03dba5e
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ dependencies = [

[project.optional-dependencies]
cellpose = [
"cellpose-napari"
"cellpose>=3.0.0,<=3.0.10"
]
testing = [
"tox",
Expand Down
11 changes: 8 additions & 3 deletions src/napari_activelearning/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def __init__(self):

self._model = None
self._model_dropout = None

self.refresh_model = True

self._transform = None

self._pretrained_model = None
Expand Down Expand Up @@ -80,8 +83,10 @@ def _model_init(self, pretrained_model=None):
self._transform = CellposeTransform(self._channels,
self._channel_axis)

self.refresh_model = False

def _run_pred(self, img, *args, **kwargs):
if self._model is None:
if self.refresh_model:
self._model_init(pretrained_model=self._pretrained_model)

x = self._transform(img)
Expand All @@ -94,7 +99,7 @@ def _run_pred(self, img, *args, **kwargs):
return probs

def _run_eval(self, img, *args, **kwargs):
if self._model is None:
if self.refresh_model:
self._model_init(pretrained_model=self._pretrained_model)

seg, _, _ = self._model.eval(img, diameter=None,
Expand Down Expand Up @@ -162,7 +167,7 @@ def _fine_tune(self, train_data, train_labels, test_data, test_labels):
model_name=self._model_name
)

self._model_init(pretrained_model=self._pretrained_model)
self.refresh_model = True

USING_CELLPOSE = True

Expand Down
41 changes: 29 additions & 12 deletions src/napari_activelearning/_models_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@ def cellpose_segmentation_parameters(
"min": 0,
"max": 2**16}] = 2,
channels: tuple[int, int] = (0, 0),
pretrained_model: Path = Path(""),
model_type: Literal["cyto", "nuclei", "tissuenet_cp3"] = "cyto",
pretrained_model: Annotated[Path, {"widget_type": "FileEdit",
"mode": "r"}] = Path(""),
model_type: Literal["custom",
"cyto",
"cyto2",
"cyto3",
"nuclei",
"tissuenet_cp3"] = "cyto3",
gpu: bool = True
):
return dict(
Expand Down Expand Up @@ -185,7 +191,8 @@ def __init__(self):
)

for par_name in finetuning_parameter_names:
self._finetuning_parameters.__getattr__(par_name).changed.connect(
self._finetuning_parameters.__getattr__(par_name).changed\
.connect(
partial(self._set_parameter, parameter_key="_" + par_name)
)

Expand All @@ -208,12 +215,26 @@ def __init__(self):
self._finetuning_parameters_scr.hide()

def _set_parameter(self, parameter_val, parameter_key=None):
if ((isinstance(parameter_val, (str, Path)) and not parameter_val)
if (((parameter_key in {"_save_path", "_pretrained_model"})
and not parameter_val.exists())
or (isinstance(parameter_val, (int, float))
and parameter_val < 0)):
parameter_val = None

self.__setattr__(parameter_key, parameter_val)
if parameter_key == "_model_type":
if parameter_val == "custom":
self._segmentation_parameters\
.pretrained_model\
.visible = True
else:
self._segmentation_parameters\
.pretrained_model\
.visible = False
self._pretrained_model = None

if getattr(self, parameter_key) != parameter_key:
self.refresh_model = True
setattr(self, parameter_key, parameter_val)

def _show_segmentation_parameters(self, show: bool):
self._segmentation_parameters_scr.setVisible(show)
Expand Down Expand Up @@ -275,7 +296,8 @@ def __init__(self):
)

for par_name in segmentation_parameter_names:
self._segmentation_parameters.__getattr__(par_name).changed.connect(
self._segmentation_parameters.__getattr__(par_name).changed\
.connect(
partial(self._set_parameter, parameter_key="_" + par_name)
)

Expand All @@ -291,12 +313,7 @@ def __init__(self):
self._segmentation_parameters_scr.hide()

def _set_parameter(self, parameter_val, parameter_key=None):
if ((isinstance(parameter_val, (str, Path)) and not parameter_val)
or (isinstance(parameter_val, (int, float))
and parameter_val < 0)):
parameter_val = None

self.__setattr__(parameter_key, parameter_val)
setattr(self, parameter_key, parameter_val)

def _show_segmentation_parameters(self, show: bool):
self._segmentation_parameters_scr.setVisible(show)
4 changes: 3 additions & 1 deletion src/napari_activelearning/_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ def get_acquisition_function_widget():
global CURRENT_ACQUISITION_FUNCTION

if CURRENT_ACQUISITION_FUNCTION is None:
segmentation_method = SEGMENTATION_METHOD_CLASS()

CURRENT_ACQUISITION_FUNCTION = AcquisitionFunctionWidget(
image_groups_manager=get_image_groups_manager_widget(),
labels_manager=get_label_groups_manager_widget(),
tunable_segmentation_method=SEGMENTATION_METHOD_CLASS()
tunable_segmentation_method=segmentation_method,
)

return CURRENT_ACQUISITION_FUNCTION
Expand Down

0 comments on commit 03dba5e

Please sign in to comment.