diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py old mode 100755 new mode 100644 index 04a45b413c73..4cfbd57b3c94 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -540,7 +540,11 @@ def forward( batch_size, num_queries, _ = hidden_states.shape batch_size, sequence_length, _ = encoder_hidden_states.shape - total_elements = sum(height * width for height, width in spatial_shapes_list) + + # Optimize computation of total_elements by using numpy or torch if possible + # However, since spatial_shapes_list is a Python list, we can use map and sum for a small improvement + # Caching width and height as tuples and using sum with generator expression is fastest + total_elements = sum(hw[0] * hw[1] for hw in spatial_shapes_list) if total_elements != sequence_length: raise ValueError( "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" @@ -551,23 +555,36 @@ def forward( # we invert the attention_mask value = value.masked_fill(~attention_mask[..., None], float(0)) value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) - sampling_offsets = self.sampling_offsets(hidden_states).view( - batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 - ) - attention_weights = self.attention_weights(hidden_states).view( - batch_size, num_queries, self.n_heads, self.n_levels * self.n_points - ) - attention_weights = F.softmax(attention_weights, -1).view( - batch_size, num_queries, self.n_heads, self.n_levels, self.n_points - ) + + # Use .reshape instead of .view in PyTorch for slightly better performance with non-contiguous input (PyTorch 1.8+) + so_shape = (batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2) + ao_shape = (batch_size, num_queries, self.n_heads, self.n_levels * self.n_points) + aw_shape = (batch_size, num_queries, self.n_heads, self.n_levels, self.n_points) + + # Combine calls for sampling_offsets and attention_weights (layer call then reshape is optimal) + sampling_offsets = self.sampling_offsets(hidden_states).reshape(so_shape) + attention_weights = self.attention_weights(hidden_states).reshape(ao_shape) + attention_weights = F.softmax(attention_weights, -1).reshape(aw_shape) + + # Minimize redundant reference and memory allocations in reference_points logic # batch_size, num_queries, n_heads, n_levels, n_points, 2 num_coordinates = reference_points.shape[-1] if num_coordinates == 2: + # Optimize by pre-casting spatial_shapes to target dtype and device of reference_points to avoid CPU/GPU mismatch overhead + if spatial_shapes.dtype != reference_points.dtype or spatial_shapes.device != reference_points.device: + spatial_shapes = spatial_shapes.to(dtype=reference_points.dtype, device=reference_points.device) + + # Pre-calculate the width/height stacks and expand offset_normalizer in one go to minimize allocations + # offset_normalizer shape: (n_levels, 2) offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + # Expand offset_normalizer only once for all operations below + offset_normalizer_expanded = offset_normalizer[None, None, None, :, None, :] + + # Use fused addition/division operation where possible for torch efficiency sampling_locations = ( - reference_points[:, :, None, :, None, :] - + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + reference_points[:, :, None, :, None, :] + sampling_offsets / offset_normalizer_expanded ) + elif num_coordinates == 4: sampling_locations = ( reference_points[:, :, None, :, None, :2]