41 changes: 29 additions & 12 deletions src/transformers/models/deformable_detr/modeling_deformable_detr.py
100755 → 100644
@@ -540,7 +540,11 @@ def forward(

batch_size, num_queries, _ = hidden_states.shape
batch_size, sequence_length, _ = encoder_hidden_states.shape
total_elements = sum(height * width for height, width in spatial_shapes_list)

# spatial_shapes_list is a plain Python list, so a generator expression is the
# cheapest way to sum the per-level element counts; indexing the (height, width)
# tuples avoids the small cost of unpacking them on every iteration
total_elements = sum(hw[0] * hw[1] for hw in spatial_shapes_list)
if total_elements != sequence_length:
raise ValueError(
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
@@ -551,23 +555,36 @@ def forward(
# we invert the attention_mask
value = value.masked_fill(~attention_mask[..., None], float(0))
value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(hidden_states).view(
batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
)
attention_weights = self.attention_weights(hidden_states).view(
batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
)
attention_weights = F.softmax(attention_weights, -1).view(
batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
)

# Use .reshape instead of .view: identical for contiguous tensors, and it also handles non-contiguous input without raising
so_shape = (batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2)
ao_shape = (batch_size, num_queries, self.n_heads, self.n_levels * self.n_points)
aw_shape = (batch_size, num_queries, self.n_heads, self.n_levels, self.n_points)

# Apply each projection and reshape its output in a single expression
sampling_offsets = self.sampling_offsets(hidden_states).reshape(so_shape)
attention_weights = self.attention_weights(hidden_states).reshape(ao_shape)
attention_weights = F.softmax(attention_weights, -1).reshape(aw_shape)
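# For context, a standalone sketch (not part of this diff) of the .view vs .reshape
# distinction the comment above refers to: .reshape matches .view for contiguous
# tensors such as the linear-projection outputs here, and falls back to a copy when
# a view is impossible.
import torch

x = torch.randn(4, 6).t()     # transpose -> shape (6, 4), non-contiguous
try:
    x.view(24)                # raises RuntimeError: layout incompatible with the requested view
except RuntimeError:
    flat = x.reshape(24)      # reshape falls back to a copy and succeeds

y = torch.randn(2, 3, 256)    # contiguous, like a nn.Linear output
assert y.reshape(2, 3, 8, 32).data_ptr() == y.view(2, 3, 8, 32).data_ptr()  # no copy for contiguous input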

# Avoid redundant casts and intermediate allocations in the reference_points logic below
# batch_size, num_queries, n_heads, n_levels, n_points, 2
num_coordinates = reference_points.shape[-1]
if num_coordinates == 2:
# Cast spatial_shapes to the dtype and device of reference_points once, so the division below does not incur a dtype/device conversion
if spatial_shapes.dtype != reference_points.dtype or spatial_shapes.device != reference_points.device:
spatial_shapes = spatial_shapes.to(dtype=reference_points.dtype, device=reference_points.device)

# Build the (width, height) normalizer once; shape: (n_levels, 2)
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
# Create the broadcastable view of offset_normalizer once and reuse it below
offset_normalizer_expanded = offset_normalizer[None, None, None, :, None, :]

# Compute the sampling locations in a single broadcasted expression
sampling_locations = (
reference_points[:, :, None, :, None, :]
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
reference_points[:, :, None, :, None, :] + sampling_offsets / offset_normalizer_expanded
)
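# For context, a standalone sketch (not part of this diff, with made-up dimensions)
# of how the 2-coordinate branch combines reference points and offsets: offsets are
# divided by the per-level (width, height) so they are added in normalized coordinates.
import torch

batch_size, num_queries, n_heads, n_levels, n_points = 2, 3, 8, 4, 4
spatial_shapes = torch.tensor([[64, 64], [32, 32], [16, 16], [8, 8]])   # (height, width) per level

reference_points = torch.rand(batch_size, num_queries, n_levels, 2)     # normalized (x, y)
sampling_offsets = torch.randn(batch_size, num_queries, n_heads, n_levels, n_points, 2)

# (width, height) per level, so dividing converts pixel offsets to normalized units
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = (
    reference_points[:, :, None, :, None, :]
    + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
)
# -> (batch_size, num_queries, n_heads, n_levels, n_points, 2)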

elif num_coordinates == 4:
sampling_locations = (
reference_points[:, :, None, :, None, :2]