uniformize kwargs for SAM (#34578)
* Make kwargs uniform for SAM

* Remove unused attribute

* Make point_pad_value part of image_kwargs

* Update annotations

* Code review - use existing methods

* Use ProcessorTesterMixin

* Do not add ProcessorTesterMixin everywhere
tibor-reiss authored Dec 23, 2024
1 parent 2bb6098 commit e10be82
Showing 2 changed files with 81 additions and 29 deletions.
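For context, a minimal usage sketch of the call signature this commit uniformizes (the checkpoint name and prompt values are illustrative assumptions, not part of the diff). Prompt inputs such as input_points, input_labels, input_boxes and point_pad_value become plain keyword arguments merged through SamProcessorKwargs, with point_pad_value defaulting to -10 unless overridden:

import numpy as np
from PIL import Image

from transformers import SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")  # illustrative checkpoint
image = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))

inputs = processor(
    images=image,
    input_points=[[[128.0, 128.0]]],  # one point prompt for the single image
    input_labels=[[1]],               # 1 = foreground point
    point_pad_value=-10,              # optional: default comes from SamProcessorKwargs._defaults
    return_tensors="pt",
)

Passing segmentation_maps, input_points, input_labels and input_boxes positionally still works via prepare_and_validate_optional_call_args, but only for backward compatibility and is slated for removal.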
80 changes: 60 additions & 20 deletions src/transformers/models/sam/processing_sam.py
@@ -17,13 +17,14 @@
"""

from copy import deepcopy
from typing import Optional, Union
from typing import List, Optional, Union

import numpy as np

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...utils import TensorType, is_tf_available, is_torch_available
from ...image_utils import ImageInput, VideoInput
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput
from ...utils import is_tf_available, is_torch_available


if is_torch_available():
@@ -33,6 +34,23 @@
import tensorflow as tf


class SamImagesKwargs(ImagesKwargs):
segmentation_maps: Optional[ImageInput]
input_points: Optional[List[List[float]]]
input_labels: Optional[List[List[int]]]
input_boxes: Optional[List[List[List[float]]]]
point_pad_value: Optional[int]


class SamProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: SamImagesKwargs
_defaults = {
"images_kwargs": {
"point_pad_value": -10,
}
}


class SamProcessor(ProcessorMixin):
r"""
Constructs a SAM processor which wraps a SAM image processor and a 2D points & Bounding boxes processor into a
@@ -48,32 +66,50 @@ class SamProcessor(ProcessorMixin):

attributes = ["image_processor"]
image_processor_class = "SamImageProcessor"
# For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
optional_call_args = [
"segmentation_maps",
"input_points",
"input_labels",
"input_boxes",
]

def __init__(self, image_processor):
super().__init__(image_processor)
self.current_processor = self.image_processor
self.point_pad_value = -10
self.target_size = self.image_processor.size["longest_edge"]

def __call__(
self,
images=None,
segmentation_maps=None,
input_points=None,
input_labels=None,
input_boxes=None,
return_tensors: Optional[Union[str, TensorType]] = None,
images: Optional[ImageInput] = None,
# The following is to capture `segmentation_maps`, `input_points`, `input_labels` and `input_boxes`
# arguments that may be passed as a positional argument.
# See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
# or this conversation for more context:
# https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
# This behavior is only needed for backward compatibility and will be removed in future versions.
*args, # to be deprecated
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
audio: Optional[AudioInput] = None,
video: Optional[VideoInput] = None,
**kwargs,
) -> BatchEncoding:
"""
This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D
points and bounding boxes for the model if they are provided.
"""
output_kwargs = self._merge_kwargs(
SamProcessorKwargs,
tokenizer_init_kwargs={},
**kwargs,
**self.prepare_and_validate_optional_call_args(*args),
)
input_points = output_kwargs["images_kwargs"].pop("input_points", None)
input_labels = output_kwargs["images_kwargs"].pop("input_labels", None)
input_boxes = output_kwargs["images_kwargs"].pop("input_boxes", None)

encoding_image_processor = self.image_processor(
images,
segmentation_maps=segmentation_maps,
return_tensors=return_tensors,
**kwargs,
**output_kwargs["images_kwargs"],
)

# pop arguments that are not used in the forward but used nevertheless
@@ -94,7 +130,8 @@ def __call__(
input_points=input_points,
input_labels=input_labels,
input_boxes=input_boxes,
return_tensors=return_tensors,
return_tensors=output_kwargs["common_kwargs"].get("return_tensors"),
point_pad_value=output_kwargs["images_kwargs"].get("point_pad_value"),
)

return encoding_image_processor
@@ -107,6 +144,7 @@ def _normalize_and_convert(
input_labels=None,
input_boxes=None,
return_tensors="pt",
point_pad_value=-10,
):
if input_points is not None:
if len(original_sizes) != len(input_points):
@@ -121,7 +159,9 @@ def _normalize_and_convert(
# check that all arrays have the same shape
if not all(point.shape == input_points[0].shape for point in input_points):
if input_labels is not None:
input_points, input_labels = self._pad_points_and_labels(input_points, input_labels)
input_points, input_labels = self._pad_points_and_labels(
input_points, input_labels, point_pad_value
)

input_points = np.array(input_points)

@@ -174,7 +214,7 @@ def _normalize_and_convert(

return encoding_image_processor

def _pad_points_and_labels(self, input_points, input_labels):
def _pad_points_and_labels(self, input_points, input_labels, point_pad_value):
r"""
The method pads the 2D points and labels to the maximum number of points in the batch.
"""
@@ -183,9 +223,9 @@ def _pad_points_and_labels(self, input_points, input_labels):
for i, point in enumerate(input_points):
if point.shape[0] != expected_nb_points:
point = np.concatenate(
[point, np.zeros((expected_nb_points - point.shape[0], 2)) + self.point_pad_value], axis=0
[point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0
)
input_labels[i] = np.append(input_labels[i], [self.point_pad_value])
input_labels[i] = np.append(input_labels[i], [point_pad_value])
processed_input_points.append(point)
input_points = processed_input_points
return input_points, input_labels
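A self-contained sketch (with made-up coordinates) of the padding behaviour that now receives point_pad_value explicitly instead of reading self.point_pad_value, mirroring _pad_points_and_labels above:

import numpy as np

point_pad_value = -10  # default carried in SamProcessorKwargs
input_points = [np.array([[10.0, 20.0], [30.0, 40.0]]), np.array([[50.0, 60.0]])]
input_labels = [np.array([1, 1]), np.array([1])]

# Pad every example up to the largest number of points in the batch and mark
# the padded points and their labels with point_pad_value.
expected_nb_points = max(point.shape[0] for point in input_points)
padded_points = []
for i, point in enumerate(input_points):
    if point.shape[0] != expected_nb_points:
        point = np.concatenate(
            [point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0
        )
        input_labels[i] = np.append(input_labels[i], [point_pad_value])
    padded_points.append(point)

# padded_points[1] -> [[50., 60.], [-10., -10.]]; input_labels[1] -> [1, -10]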
30 changes: 21 additions & 9 deletions tests/models/sam/test_processor_sam.py
@@ -26,7 +26,7 @@
)
from transformers.utils import is_tf_available, is_torch_available, is_vision_available

from ...test_processing_common import prepare_image_inputs
from ...test_processing_common import ProcessorTesterMixin, prepare_image_inputs


if is_vision_available():
@@ -43,7 +43,9 @@

@require_vision
@require_torchvision
class SamProcessorTest(unittest.TestCase):
class SamProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = SamProcessor

def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = SamImageProcessor()
@@ -56,11 +58,6 @@ def get_image_processor(self, **kwargs):
def tearDown(self):
shutil.rmtree(self.tmpdirname)

# Processor tester class can't use ProcessorTesterMixin atm because the processor is atypical e.g. only contains an image processor
def prepare_image_inputs(self):
"""This function prepares a list of PIL images."""
return prepare_image_inputs()

def prepare_mask_inputs(self):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True.
@@ -69,6 +66,21 @@ def prepare_mask_inputs(self):
mask_inputs = [Image.fromarray(x) for x in mask_inputs]
return mask_inputs

def test_chat_template_save_loading(self):
self.skipTest("SamProcessor does not have a tokenizer")

def test_image_processor_defaults_preserved_by_image_kwargs(self):
self.skipTest("SamProcessor does not have a tokenizer")

def test_kwargs_overrides_default_image_processor_kwargs(self):
self.skipTest("SamProcessor does not have a tokenizer")

def test_kwargs_overrides_default_tokenizer_kwargs(self):
self.skipTest("SamProcessor does not have a tokenizer")

def test_tokenizer_defaults_preserved_by_kwargs(self):
self.skipTest("SamProcessor does not have a tokenizer")

def test_save_load_pretrained_additional_features(self):
processor = SamProcessor(image_processor=self.get_image_processor())
processor.save_pretrained(self.tmpdirname)
@@ -165,7 +177,7 @@ def get_image_processor(self, **kwargs):
def tearDown(self):
shutil.rmtree(self.tmpdirname)

# Processor tester class can't use ProcessorTesterMixin as processor is atypical e.g. only contains an image processor and it assumes torch
# This is to avoid repeating the skipping of the common tests
def prepare_image_inputs(self):
"""This function prepares a list of PIL images."""
return prepare_image_inputs()
@@ -248,7 +260,7 @@ def get_image_processor(self, **kwargs):
def tearDown(self):
shutil.rmtree(self.tmpdirname)

# Processor tester class can't use ProcessorTesterMixin atm because the processor is atypical e.g. only contains an image processor
# This is to avoid repeating the skipping of the common tests
def prepare_image_inputs(self):
"""This function prepares a list of PIL images."""
return prepare_image_inputs()