Skip to content

Commit

Permalink
models.layoutlmv3: Add SDPA support for LayoutLMv3 model
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniel Stancl committed Dec 31, 2024
1 parent d5aebc6 commit a542ec9
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 3 deletions.
23 changes: 23 additions & 0 deletions docs/source/en/model_doc/layoutlmv3.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,29 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The Tenso
Due to these differences in data preprocessing, one can use [`LayoutLMv3Processor`] which internally combines a [`LayoutLMv3ImageProcessor`] (for the image modality) and a [`LayoutLMv3Tokenizer`]/[`LayoutLMv3TokenizerFast`] (for the text modality) to prepare all data for the model.
- Regarding usage of [`LayoutLMv3Processor`], we refer to the [usage guide](layoutlmv2#usage-layoutlmv2processor) of its predecessor.

### Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```
from transformers import LayoutLMv3Model
model = LayoutLMv3Model.from_pretrained("bert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa")
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (A100-80GB, CPUx12, RAM 96.6GB, PyTorch 2.2.0, OS Ubuntu 22.04) with `float16`, we saw the
following speedups during training and inference.

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with LayoutLMv3. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
Expand Down
97 changes: 94 additions & 3 deletions src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from packaging import version
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
from ...modeling_outputs import (
BaseModelOutput,
QuestionAnsweringModelOutput,
Expand All @@ -36,6 +38,7 @@
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
get_torch_version,
logging,
replace_return_docstrings,
torch_int,
Expand Down Expand Up @@ -358,6 +361,7 @@ class LayoutLMv3PreTrainedModel(PreTrainedModel):

config_class = LayoutLMv3Config
base_model_prefix = "layoutlmv3"
_supports_sdpa = True

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down Expand Up @@ -466,6 +470,71 @@ def forward(
return outputs


class LayoutLMv3SdpaSelfAttention(LayoutLMv3SelfAttention):
def __init__(self, config: LayoutLMv3Config) -> None:
super().__init__(config)
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")

# Adapted from LayoutLMv3SelfAttention
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
rel_pos: Optional[torch.Tensor] = None,
rel_2d_pos: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor]:
if (
self.has_relative_attention_bias
or self.has_spatial_attention_bias
or output_attentions
or head_mask is not None
):
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
logger.warning_once(
"LayoutLMv3SdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`relative_attention_bias` or `spatial_attention_bias or `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
attention_mask,
head_mask,
output_attentions,
rel_pos,
rel_2d_pos,
)

batch_size, seq_len = hidden_states.shape[:2]

query_layer = self.transpose_for_scores(self.query(hidden_states))
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))

# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, seq_len, self.all_head_size)
attn_output = self.dropout(attn_output)

return (attn_output,)


# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput
class LayoutLMv3SelfOutput(nn.Module):
def __init__(self, config):
Expand All @@ -481,10 +550,16 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
return hidden_states


# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
LAYOUTLMV3_SELF_ATTENTION_CLASSES = {
"eager": LayoutLMv3SelfAttention,
"sdpa": LayoutLMv3SdpaSelfAttention,
}


class LayoutLMv3Attention(nn.Module):
def __init__(self, config):
super().__init__()
self.self = LAYOUTLMV3_SELF_ATTENTION_CLASSES[config._attn_implementation](config)
self.self = LayoutLMv3SelfAttention(config)
self.output = LayoutLMv3SelfOutput(config)

Expand Down Expand Up @@ -774,6 +849,8 @@ def __init__(self, config):

self.encoder = LayoutLMv3Encoder(config)

self.attn_implementation = config._attn_implementation

self.init_weights()

def get_input_embeddings(self):
Expand Down Expand Up @@ -961,10 +1038,24 @@ def forward(
position_ids = position_ids.expand_as(input_ids)
final_position_ids = position_ids

extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
attention_mask, None, device, dtype=embedding_output.dtype
use_sdpa_attention_mask = (
self.attn_implementation == "sdpa"
and not self.config.has_relative_attention_bias
and not self.config.has_spatial_attention_bias
and head_mask is None
and not output_attentions
)

# Expand the attention mask
if use_sdpa_attention_mask and attention_mask.dim() == 2:
extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
attention_mask, embedding_output.dtype, tgt_len=seq_length
)
else:
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
attention_mask, None, device, dtype=embedding_output.dtype
)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
Expand Down
23 changes: 23 additions & 0 deletions tests/models/layoutlmv3/test_modeling_layoutlmv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import copy
import unittest

from transformers import AutoTokenizer
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
Expand Down Expand Up @@ -378,6 +379,28 @@ def test_model_from_pretrained(self):
model = LayoutLMv3Model.from_pretrained(model_name)
self.assertIsNotNone(model)

def test_sdpa(self):
model = LayoutLMv3Model.from_pretrained(
"hf-tiny-model-private/tiny-random-LayoutLMv3Model", attn_implementation="eager"
)
model_sdpa = LayoutLMv3Model.from_pretrained(
"hf-tiny-model-private/tiny-random-LayoutLMv3Model", attn_implementation="sdpa"
)

model = model.eval()
model_sdpa = model_sdpa.eval()

tokenizer = AutoTokenizer.from_pretrained("hf-tiny-model-private/tiny-random-LayoutLMv3Model")
words = "I am in Paris and".split()
inputs = tokenizer(text=[words], boxes=[[(0, 0, 1, 1)] * len(words)], return_tensors="pt")

with torch.no_grad():
res_eager = model(**inputs)
res_sdpa = model_sdpa(**inputs)
self.assertTrue(
torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
)


# We will verify our results on an image of cute cats
def prepare_img():
Expand Down

0 comments on commit a542ec9

Please sign in to comment.