diff --git a/label_anything/sam/mmdetection.py b/label_anything/sam/mmdetection.py index 0c30271..74a13ba 100644 --- a/label_anything/sam/mmdetection.py +++ b/label_anything/sam/mmdetection.py @@ -62,6 +62,7 @@ def load_my_model( class MMDetection(LabelStudioMLBase): """Object detector based on https://github.com/open-mmlab/mmdetection.""" + _predictor = None def __init__(self, model_name="sam_hq", config_file=None, @@ -78,10 +79,13 @@ def __init__(self, **kwargs): super(MMDetection, self).__init__(**kwargs) - - PREDICTOR = load_my_model( - model_name, device, sam_config, sam_checkpoint_file) - self.PREDICTOR = PREDICTOR + + # Only load the model if it hasn't been loaded before + if MMDetection._predictor is None: + MMDetection._predictor = load_my_model( + model_name, device, sam_config, sam_checkpoint_file) + + self.PREDICTOR = MMDetection._predictor self.out_mask = out_mask self.out_bbox = out_bbox