diff --git a/src/transformers/models/llama4/image_processing_llama4_fast.py b/src/transformers/models/llama4/image_processing_llama4_fast.py
index ef44786f7c66..de42ea5ca321 100644
--- a/src/transformers/models/llama4/image_processing_llama4_fast.py
+++ b/src/transformers/models/llama4/image_processing_llama4_fast.py
@@ -244,38 +244,40 @@ def get_best_fit(
     original_height, original_width = image_size
     # get all possible resolutions heights/widths
-    target_heights, target_widths = (
-        possible_resolutions[:, 0],
-        possible_resolutions[:, 1],
-    )
+    target_heights = possible_resolutions[:, 0]
+    target_widths = possible_resolutions[:, 1]
 
     # get scaling factors to resize the image without distortion
     scale_w = target_widths / original_width
     scale_h = target_heights / original_height
 
     # get the min scale between width and height (limiting side -> no distortion)
-    scales = torch.where(scale_h > scale_w, scale_w, scale_h)
+    scales = torch.minimum(scale_h, scale_w)  # slightly faster than torch.where for simple min
 
     # filter only scales that allow upscaling
-    upscaling_options = scales[scales >= 1]
-    if len(upscaling_options) > 0:
+    upscaling_mask = scales >= 1
+    upscaling_options = scales[upscaling_mask]
+
+    if upscaling_options.numel() > 0:
         if resize_to_max_canvas:
-            selected_scale = torch.max(upscaling_options)
+            selected_scale = upscaling_options.max()
         else:
-            selected_scale = torch.min(upscaling_options)
+            selected_scale = upscaling_options.min()
     else:
         # no upscaling possible,
         # get the minimum downscaling (max scale for scales<1)
-        downscaling_options = scales[scales < 1]
-        selected_scale = torch.max(downscaling_options)
+        downscaling_mask = scales < 1
+        downscaling_options = scales[downscaling_mask]
+        selected_scale = downscaling_options.max()
 
     # get all resolutions that support this scaling factor,
     # e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
-    chosen_canvas = possible_resolutions[scales == selected_scale]
+    chosen_mask = scales == selected_scale
+    chosen_canvas = possible_resolutions[chosen_mask]
 
     # if there are multiple resolutions,
     # get the one with minimum area to reduce padding
-    if len(chosen_canvas) > 1:
+    if chosen_canvas.size(0) > 1:
         areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
         optimal_idx = torch.argmin(areas)
         optimal_canvas = chosen_canvas[optimal_idx]