From f1f8c9a1280d5567339e86cb8d6fcd23d31d8067 Mon Sep 17 00:00:00 2001 From: ankitade Date: Mon, 20 Jun 2022 02:45:10 +0000 Subject: [PATCH 1/3] [FLAVA] Separate out text and image encoders [ghstack-poisoned] --- test/models/flava/__init__.py | 5 + test/models/{ => flava}/test_flava.py | 0 test/models/flava/test_flava_image_encoder.py | 167 ++++++ test/models/flava/test_flava_text_encoder.py | 97 ++++ .../models/flava/flava_image_encoder.py | 304 ++++++++++ torchmultimodal/models/flava/flava_model.py | 532 +----------------- .../models/flava/flava_text_encoder.py | 264 +++++++++ 7 files changed, 840 insertions(+), 529 deletions(-) create mode 100644 test/models/flava/__init__.py rename test/models/{ => flava}/test_flava.py (100%) create mode 100644 test/models/flava/test_flava_image_encoder.py create mode 100644 test/models/flava/test_flava_text_encoder.py create mode 100644 torchmultimodal/models/flava/flava_image_encoder.py create mode 100644 torchmultimodal/models/flava/flava_text_encoder.py diff --git a/test/models/flava/__init__.py b/test/models/flava/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/test/models/flava/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/test/models/test_flava.py b/test/models/flava/test_flava.py similarity index 100% rename from test/models/test_flava.py rename to test/models/flava/test_flava.py diff --git a/test/models/flava/test_flava_image_encoder.py b/test/models/flava/test_flava_image_encoder.py new file mode 100644 index 00000000..453b542b --- /dev/null +++ b/test/models/flava/test_flava_image_encoder.py @@ -0,0 +1,167 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch +from test.test_utils import assert_expected, set_rng_seed +from torch import nn +from torchmultimodal.models.flava.flava_image_encoder import ( + ImageEmbeddings, + ImageTransformer, +) +from torchmultimodal.modules.layers.transformer import FLAVATransformerEncoder + + +class TestFlavaImageEncoder(unittest.TestCase): + def setUp(self): + set_rng_seed(0) + torch.manual_seed(0) + self.image_embedding = ImageEmbeddings( + image_size=2, patch_size=1, hidden_size=2 + ) + + encoder = FLAVATransformerEncoder( + hidden_size=2, + num_attention_heads=1, + num_hidden_layers=1, + hidden_dropout_prob=0.0, + intermediate_size=1, + attention_probs_dropout_prob=0.0, + ) + self.image_encoder = ImageTransformer( + embeddings=self.image_embedding, + encoder=encoder, + layernorm=nn.LayerNorm(2), + pooler=nn.Identity(), + ) + + def test_embedding(self): + input = torch.ones(2, 3, 2, 2) + out = self.image_embedding(input) + assert_expected( + out, + torch.Tensor( + [ + [ + [0.0000, 0.0000], + [0.0224, 0.0573], + [0.0224, 0.0573], + [0.0224, 0.0573], + [0.0224, 0.0573], + ], + [ + [0.0000, 0.0000], + [0.0224, 0.0573], + [0.0224, 0.0573], + [0.0224, 0.0573], + [0.0224, 0.0573], + ], + ] + ), + atol=1e-4, + rtol=0, + ) + + def test_image_encoder(self): + input = torch.ones(2, 3, 2, 2) + out = self.image_encoder(input) + assert_expected( + out.last_hidden_state, + torch.Tensor( + [ + [ + [-0.0040, 0.0040], + [-0.9840, 0.9840], + [-0.9840, 0.9840], + [-0.9840, 0.9840], + [-0.9840, 0.9840], + ], + [ + [-0.0040, 0.0040], + [-0.9840, 0.9840], + [-0.9840, 0.9840], + [-0.9840, 0.9840], + [-0.9840, 0.9840], + ], + ] + ), + atol=1e-4, + rtol=0, + ) + assert_expected(out.pooler_output, out.last_hidden_state) + assert_expected( + out.hidden_states, + ( + torch.Tensor( + [ + [ + [0.0000, 0.0000], + [0.0224, 0.0573], + [0.0224, 0.0573], + [0.0224, 0.0573], + [0.0224, 0.0573], + ], + [ + [0.0000, 0.0000], + [0.0224, 0.0573], + [0.0224, 0.0573], + [0.0224, 0.0573], + [0.0224, 0.0573], + ], + ] + ), + torch.Tensor( + [ + [ + [0.0008, 0.0008], + [0.0232, 0.0581], + [0.0232, 0.0581], + [0.0232, 0.0581], + [0.0232, 0.0581], + ], + [ + [0.0008, 0.0008], + [0.0232, 0.0581], + [0.0232, 0.0581], + [0.0232, 0.0581], + [0.0232, 0.0581], + ], + ] + ), + ), + atol=1e-4, + rtol=0, + ) + assert_expected( + out.attentions, + ( + torch.Tensor( + [ + [ + [ + [0.2000, 0.2000, 0.2000, 0.2000, 0.2000], + [0.1999, 0.2000, 0.2000, 0.2000, 0.2000], + [0.1999, 0.2000, 0.2000, 0.2000, 0.2000], + [0.1999, 0.2000, 0.2000, 0.2000, 0.2000], + [0.1999, 0.2000, 0.2000, 0.2000, 0.2000], + ] + ], + [ + [ + [0.2000, 0.2000, 0.2000, 0.2000, 0.2000], + [0.1999, 0.2000, 0.2000, 0.2000, 0.2000], + [0.1999, 0.2000, 0.2000, 0.2000, 0.2000], + [0.1999, 0.2000, 0.2000, 0.2000, 0.2000], + [0.1999, 0.2000, 0.2000, 0.2000, 0.2000], + ] + ], + ] + ), + ), + atol=1e-4, + rtol=0, + ) diff --git a/test/models/flava/test_flava_text_encoder.py b/test/models/flava/test_flava_text_encoder.py new file mode 100644 index 00000000..c1a690d8 --- /dev/null +++ b/test/models/flava/test_flava_text_encoder.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch +from test.test_utils import assert_expected, set_rng_seed +from torch import nn +from torchmultimodal.models.flava.flava_text_encoder import ( + TextEmbeddings, + TextTransformer, +) +from torchmultimodal.modules.layers.transformer import FLAVATransformerEncoder + + +class TestFlavaTextEncoder(unittest.TestCase): + def setUp(self): + set_rng_seed(0) + self.text_embedding = TextEmbeddings( + hidden_size=2, + vocab_size=3, + max_position_embeddings=2, + hidden_dropout_prob=0, + ) + emb_weights = torch.Tensor([[0, 1], [1, 0], [1, 1]]) + self.text_embedding.word_embeddings = nn.Embedding.from_pretrained(emb_weights) + self.text_embedding.position_embeddings = nn.Embedding.from_pretrained( + emb_weights + ) + self.text_embedding.token_type_embeddings = nn.Embedding.from_pretrained( + emb_weights + ) + + encoder = FLAVATransformerEncoder( + hidden_size=2, + num_attention_heads=1, + num_hidden_layers=1, + hidden_dropout_prob=0.0, + intermediate_size=1, + attention_probs_dropout_prob=0.0, + ) + self.text_encoder = TextTransformer( + embeddings=self.text_embedding, + encoder=encoder, + layernorm=nn.LayerNorm(2), + pooler=nn.Identity(), + ) + + def test_embedding(self): + input_ids = torch.IntTensor([[0, 1]]) + out = self.text_embedding(input_ids) + expected = torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]]) + assert_expected(out, expected) + + def test_text_transformer(self): + out = self.text_encoder(torch.IntTensor([[0, 1]])) + + assert_expected( + out.last_hidden_state, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]]) + ) + + assert_expected( + out.hidden_states, + ( + torch.Tensor([[[1.0000, -1.0000], [-1.0000, 1.0000]]]), + torch.Tensor([[[1.0008, -0.9994], [-0.9997, 1.0012]]]), + ), + atol=1e-4, + rtol=0.0, + ) + + assert_expected(out.attentions, (torch.Tensor([[[[0, 1.0], [0.0, 1.0]]]]),)) + + def test_text_transformer_attn_mask(self): + input_ids = torch.IntTensor([[0, 1]]) + attn_mask = torch.IntTensor([[1, 0]]) + out = self.text_encoder(input_ids, attention_mask=attn_mask) + + assert_expected( + out.last_hidden_state, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]]) + ) + + assert_expected( + out.hidden_states, + ( + torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]]), + torch.Tensor([[[0.9997, -1.0012], [-1.0008, 0.9994]]]), + ), + atol=1e-4, + rtol=0.0, + ) + + assert_expected(out.pooler_output, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]])) + assert_expected(out.attentions, (torch.Tensor([[[[1.0, 0], [1.0, 0]]]]),)) diff --git a/torchmultimodal/models/flava/flava_image_encoder.py b/torchmultimodal/models/flava/flava_image_encoder.py new file mode 100644 index 00000000..eabad7f7 --- /dev/null +++ b/torchmultimodal/models/flava/flava_image_encoder.py @@ -0,0 +1,304 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import collections +from functools import partial +from typing import Any, Callable, Optional + +import torch +from torch import nn, Tensor +from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm +from torchmultimodal.modules.layers.transformer import ( + FLAVATransformerEncoder, + FLAVATransformerOutput, + init_transformer_weights, +) +from torchmultimodal.modules.losses.flava import Pooler + + +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + + +# Based on timm implementation, which can be found here: +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +class PatchEmbeddings(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768): + super().__init__() + image_size = to_2tuple(image_size) + patch_size = to_2tuple(patch_size) + num_patches = (image_size[1] // patch_size[1]) * ( + image_size[0] // patch_size[0] + ) + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.projection = nn.Conv2d( + num_channels, embed_dim, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, pixel_values, interpolate_pos_encoding=False): + _, _, height, width = pixel_values.shape + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + x = self.projection(pixel_values).flatten(2).transpose(1, 2) + return x + + +class ImageEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. + """ + + def __init__( + self, + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + hidden_size: int = 768, + hidden_dropout_prob: float = 0.0, + use_image_masking: bool = True, + ): + super().__init__() + + self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) + self.patch_embeddings = PatchEmbeddings( + image_size=image_size, + patch_size=patch_size, + num_channels=num_channels, + embed_dim=hidden_size, + ) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter( + torch.zeros(1, num_patches + 1, hidden_size) + ) + self.dropout = nn.Dropout(hidden_dropout_prob) + + if use_image_masking: + self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) + else: + self.mask_token = None + + def interpolate_pos_encoding(self, embeddings, height, width): + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. 
+ Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + npatch = embeddings.shape[1] - 1 + n = self.position_embeddings.shape[1] - 1 + if npatch == n and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape( + 1, int(math.sqrt(N)), int(math.sqrt(N)), dim + ).permute(0, 3, 1, 2), + scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)), + mode="bicubic", + align_corners=False, + ) + assert ( + int(h0) == patch_pos_embed.shape[-2] + and int(w0) == patch_pos_embed.shape[-1] + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: Tensor, + image_patches_mask: Optional[Tensor] = None, + interpolate_pos_encoding: bool = False, + ): + batch_size, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings( + pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + ) + + _, seq_len, _ = embeddings.size() + if image_patches_mask is not None: + if self.mask_token is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # replace the masked visual tokens by mask_tokens + w = image_patches_mask.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1 - w) + mask_tokens * w + else: + warnings.warn( + "image_patches_mask passed but use_image_masking in init was false. Ignoring." 
+ ) + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width + ) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +class ImageTransformer(nn.Module): + # TODO(asg): Add support for pretrained checkpoint loading + def __init__( + self, + embeddings: nn.Module, + encoder: nn.Module, + layernorm: nn.Module, + pooler: nn.Module, + weight_init_fn: Optional[Callable] = None, + initializer_range: float = 0.02, + **kwargs: Any, + ): + super().__init__() + + self.embeddings = embeddings + self.encoder = encoder + self.layernorm = layernorm + self.pooler = pooler + + if weight_init_fn is None: + weight_init_fn = partial( + init_transformer_weights, initializer_range=initializer_range + ) + + self.apply(weight_init_fn) + + def forward( + self, + pixel_values: Optional[Tensor] = None, + image_patches_mask: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + ): + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + embedding_output = self.embeddings( + pixel_values, image_patches_mask=image_patches_mask + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=attention_mask, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + return FLAVATransformerOutput( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +def flava_image_encoder( + hidden_size: int = 768, + num_attention_heads: int = 12, + num_hidden_layers: int = 12, + use_image_masking: bool = False, + hidden_dropout_prob: float = 0.0, + intermediate_size: int = 3072, + intermediate_activation: Callable[..., Tensor] = nn.functional.gelu, + attention_probs_dropout_prob: float = 0.0, + layer_norm_eps: float = 1e-12, + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, +): + + embeddings = ImageEmbeddings( + image_size=image_size, + patch_size=patch_size, + num_channels=num_channels, + hidden_size=hidden_size, + hidden_dropout_prob=hidden_dropout_prob, + use_image_masking=use_image_masking, + ) + encoder = FLAVATransformerEncoder( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_hidden_layers=num_hidden_layers, + hidden_dropout_prob=hidden_dropout_prob, + intermediate_size=intermediate_size, + intermediate_activation=intermediate_activation, + attention_probs_dropout_prob=attention_probs_dropout_prob, + layer_norm_eps=layer_norm_eps, + ) + + layernorm = Fp32LayerNorm(hidden_size, eps=layer_norm_eps) + pooler = Pooler(hidden_size=hidden_size) + + return ImageTransformer( + embeddings=embeddings, + encoder=encoder, + layernorm=layernorm, + pooler=pooler, + ) + + +class ImageTransformerWithVAE(nn.Module): + def __init__( + self, + image_transformer: nn.Module, + vae: nn.Module, + **kwargs, + ): + super().__init__() + + self.image_transformer = image_transformer + self.vae = vae + + def forward( + self, + pixel_values: Optional[Tensor] = None, + image_patches_mask: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + ): + 
image_labels = self.vae(pixel_values).flatten(1) + image_patches_mask = image_patches_mask.flatten(1).to(torch.bool) + image_labels[image_patches_mask == False] = -1 # noqa + + output = self.image_transformer( + pixel_values=pixel_values, + image_patches_mask=image_patches_mask, + attention_mask=attention_mask, + ) + return FLAVATransformerOutput( + last_hidden_state=output.last_hidden_state, + pooler_output=output.pooler_output, + hidden_states=output.hidden_states, + attentions=output.attentions, + image_labels=image_labels, + ) diff --git a/torchmultimodal/models/flava/flava_model.py b/torchmultimodal/models/flava/flava_model.py index 2598d954..9c74f527 100644 --- a/torchmultimodal/models/flava/flava_model.py +++ b/torchmultimodal/models/flava/flava_model.py @@ -9,22 +9,21 @@ import collections import math -import warnings from collections import namedtuple, OrderedDict from dataclasses import dataclass from functools import partial from typing import Any, Callable, List, Literal, Optional, Tuple, Union import torch -from packaging import version -from torch import device, nn, Tensor +from torch import nn, Tensor +from torchmultimodal.models.flava.flava_image_encoder import flava_image_encoder +from torchmultimodal.models.flava.flava_text_encoder import flava_text_encoder from torchmultimodal.modules.layers.mlp import MLP from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm from torchmultimodal.modules.layers.transformer import ( FLAVATransformerEncoder, FLAVATransformerOutput, FLAVATransformerWithoutEmbeddings, - init_transformer_weights, ) from torchmultimodal.modules.losses.flava import ( FLAVAPretrainingLoss, @@ -56,98 +55,6 @@ } -def flava_image_encoder( - hidden_size: int = 768, - num_attention_heads: int = 12, - num_hidden_layers: int = 12, - use_image_masking: bool = False, - hidden_dropout_prob: float = 0.0, - intermediate_size: int = 3072, - intermediate_activation: Callable[..., Tensor] = nn.functional.gelu, - attention_probs_dropout_prob: float = 0.0, - layer_norm_eps: float = 1e-12, - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, -): - - embeddings = ImageEmbeddings( - image_size=image_size, - patch_size=patch_size, - num_channels=num_channels, - hidden_size=hidden_size, - hidden_dropout_prob=hidden_dropout_prob, - use_image_masking=use_image_masking, - ) - encoder = FLAVATransformerEncoder( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_hidden_layers=num_hidden_layers, - hidden_dropout_prob=hidden_dropout_prob, - intermediate_size=intermediate_size, - intermediate_activation=intermediate_activation, - attention_probs_dropout_prob=attention_probs_dropout_prob, - layer_norm_eps=layer_norm_eps, - ) - - layernorm = Fp32LayerNorm(hidden_size, eps=layer_norm_eps) - pooler = Pooler(hidden_size=hidden_size) - - return ImageTransformer( - embeddings=embeddings, - encoder=encoder, - layernorm=layernorm, - pooler=pooler, - ) - - -def flava_text_encoder( - hidden_size: int = 768, - num_attention_heads: int = 12, - num_hidden_layers: int = 12, - hidden_dropout_prob: float = 0.0, - intermediate_size: int = 3072, - intermediate_activation: Callable[..., Tensor] = nn.functional.gelu, - attention_probs_dropout_prob: float = 0.0, - layer_norm_eps: float = 1e-12, - vocab_size: int = 30522, - pad_token_id: int = 0, - type_vocab_size: int = 2, - max_position_embeddings: int = 512, -): - embeddings = TextEmbeddings( - hidden_size=hidden_size, - vocab_size=vocab_size, - pad_token_id=pad_token_id, - 
type_vocab_size=type_vocab_size, - max_position_embeddings=max_position_embeddings, - layer_norm_eps=layer_norm_eps, - hidden_dropout_prob=hidden_dropout_prob, - ) - - encoder = FLAVATransformerEncoder( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_hidden_layers=num_hidden_layers, - hidden_dropout_prob=hidden_dropout_prob, - intermediate_size=intermediate_size, - intermediate_activation=intermediate_activation, - attention_probs_dropout_prob=attention_probs_dropout_prob, - layer_norm_eps=layer_norm_eps, - pad_token_id=pad_token_id, - ) - - layernorm = Fp32LayerNorm(hidden_size, eps=layer_norm_eps) - pooler = Pooler(hidden_size=hidden_size) - - return TextTransformer( - embeddings=embeddings, - encoder=encoder, - layernorm=layernorm, - pooler=pooler, - ) - - def flava_multimodal_encoder( hidden_size: int = 768, num_attention_heads: int = 12, @@ -592,439 +499,6 @@ def forward( ) -# Based on timm implementation, which can be found here: -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py -class PatchEmbeddings(nn.Module): - """ - Image to Patch Embedding. - """ - - def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768): - super().__init__() - image_size = to_2tuple(image_size) - patch_size = to_2tuple(patch_size) - num_patches = (image_size[1] // patch_size[1]) * ( - image_size[0] // patch_size[0] - ) - self.image_size = image_size - self.patch_size = patch_size - self.num_patches = num_patches - - self.projection = nn.Conv2d( - num_channels, embed_dim, kernel_size=patch_size, stride=patch_size - ) - - def forward(self, pixel_values, interpolate_pos_encoding=False): - _, _, height, width = pixel_values.shape - if not interpolate_pos_encoding: - if height != self.image_size[0] or width != self.image_size[1]: - raise ValueError( - f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." - ) - x = self.projection(pixel_values).flatten(2).transpose(1, 2) - return x - - -class ImageEmbeddings(nn.Module): - """ - Construct the CLS token, position and patch embeddings. - """ - - def __init__( - self, - image_size: int = 224, - patch_size: int = 16, - num_channels: int = 3, - hidden_size: int = 768, - hidden_dropout_prob: float = 0.0, - use_image_masking: bool = True, - ): - super().__init__() - - self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) - self.patch_embeddings = PatchEmbeddings( - image_size=image_size, - patch_size=patch_size, - num_channels=num_channels, - embed_dim=hidden_size, - ) - num_patches = self.patch_embeddings.num_patches - self.position_embeddings = nn.Parameter( - torch.zeros(1, num_patches + 1, hidden_size) - ) - self.dropout = nn.Dropout(hidden_dropout_prob) - - if use_image_masking: - self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) - else: - self.mask_token = None - - def interpolate_pos_encoding(self, embeddings, height, width): - """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. 
- Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 - """ - - npatch = embeddings.shape[1] - 1 - n = self.position_embeddings.shape[1] - 1 - if npatch == n and height == width: - return self.position_embeddings - class_pos_embed = self.position_embeddings[:, 0] - patch_pos_embed = self.position_embeddings[:, 1:] - dim = embeddings.shape[-1] - h0 = height // self.config.patch_size - w0 = width // self.config.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - h0, w0 = h0 + 0.1, w0 + 0.1 - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed.reshape( - 1, int(math.sqrt(n)), int(math.sqrt(n)), dim - ).permute(0, 3, 1, 2), - scale_factor=(h0 / math.sqrt(n), w0 / math.sqrt(n)), - mode="bicubic", - align_corners=False, - ) - assert ( - int(h0) == patch_pos_embed.shape[-2] - and int(w0) == patch_pos_embed.shape[-1] - ) - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) - - def forward( - self, - pixel_values: Tensor, - image_patches_mask: Optional[Tensor] = None, - interpolate_pos_encoding: bool = False, - ): - batch_size, num_channels, height, width = pixel_values.shape - embeddings = self.patch_embeddings( - pixel_values, interpolate_pos_encoding=interpolate_pos_encoding - ) - - _, seq_len, _ = embeddings.size() - if image_patches_mask is not None: - if self.mask_token is not None: - mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) - # replace the masked visual tokens by mask_tokens - w = image_patches_mask.unsqueeze(-1).type_as(mask_tokens) - embeddings = embeddings * (1 - w) + mask_tokens * w - else: - warnings.warn( - "image_patches_mask passed but use_image_masking in init was false. Ignoring." 
- ) - # add the [CLS] token to the embedded patch tokens - cls_tokens = self.cls_token.expand(batch_size, -1, -1) - embeddings = torch.cat((cls_tokens, embeddings), dim=1) - - # add positional encoding to each token - if interpolate_pos_encoding: - embeddings = embeddings + self.interpolate_pos_encoding( - embeddings, height, width - ) - else: - embeddings = embeddings + self.position_embeddings - - embeddings = self.dropout(embeddings) - - return embeddings - - -class ImageTransformer(nn.Module): - # TODO(asg): Add support for pretrained checkpoint loading - def __init__( - self, - embeddings: nn.Module, - encoder: nn.Module, - layernorm: nn.Module, - pooler: nn.Module, - weight_init_fn: Optional[Callable] = None, - initializer_range: float = 0.02, - **kwargs: Any, - ): - super().__init__() - - self.embeddings = embeddings - self.encoder = encoder - self.layernorm = layernorm - self.pooler = pooler - - if weight_init_fn is None: - weight_init_fn = partial( - init_transformer_weights, initializer_range=initializer_range - ) - - self.apply(weight_init_fn) - - def forward( - self, - pixel_values: Optional[Tensor] = None, - image_patches_mask: Optional[Tensor] = None, - attention_mask: Optional[Tensor] = None, - ): - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - embedding_output = self.embeddings( - pixel_values, image_patches_mask=image_patches_mask - ) - - encoder_outputs = self.encoder( - embedding_output, - attention_mask=attention_mask, - ) - sequence_output = encoder_outputs[0] - sequence_output = self.layernorm(sequence_output) - pooled_output = ( - self.pooler(sequence_output) if self.pooler is not None else None - ) - - return FLAVATransformerOutput( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -class ImageTransformerWithVAE(nn.Module): - def __init__( - self, - image_transformer: nn.Module, - vae: nn.Module, - **kwargs, - ): - super().__init__() - - self.image_transformer = image_transformer - self.vae = vae - - def forward( - self, - pixel_values: Optional[Tensor] = None, - image_patches_mask: Optional[Tensor] = None, - attention_mask: Optional[Tensor] = None, - ): - image_labels = self.vae(pixel_values).flatten(1) - image_patches_mask = image_patches_mask.flatten(1).to(torch.bool) - image_labels[image_patches_mask == False] = -1 # noqa - - output = self.image_transformer( - pixel_values=pixel_values, - image_patches_mask=image_patches_mask, - attention_mask=attention_mask, - ) - return FLAVATransformerOutput( - last_hidden_state=output.last_hidden_state, - pooler_output=output.pooler_output, - hidden_states=output.hidden_states, - attentions=output.attentions, - image_labels=image_labels, - ) - - -class TextEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings following BERT.""" - - def __init__( - self, - hidden_size: int = 768, - vocab_size: int = 30522, - pad_token_id: int = 0, - type_vocab_size: int = 2, - max_position_embeddings: int = 512, - layer_norm_eps: float = 1e-12, - hidden_dropout_prob: float = 0.1, - ): - super().__init__() - self.word_embeddings = nn.Embedding( - vocab_size, hidden_size, padding_idx=pad_token_id - ) - self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) - self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) - - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to 
load - # any TensorFlow checkpoint file - self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) - self.dropout = nn.Dropout(hidden_dropout_prob) - # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.register_buffer( - "position_ids", torch.arange(max_position_embeddings).expand((1, -1)) - ) - if version.parse(torch.__version__) > version.parse("1.6.0"): - self.register_buffer( - "token_type_ids", - torch.zeros(self.position_ids.size(), dtype=torch.long), - persistent=False, - ) - - def forward( - self, - input_ids: Optional[Tensor] = None, - token_type_ids: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - inputs_embeds: Optional[Tensor] = None, - past_key_values_length: int = 0, - ): - if input_ids is not None: - input_shape = input_ids.size() - else: - input_shape = inputs_embeds.size()[:-1] - - seq_length = input_shape[1] - - if position_ids is None: - position_ids = self.position_ids[ - :, past_key_values_length : seq_length + past_key_values_length - ] - - # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs - # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves - # issue #5664 - if token_type_ids is None: - if hasattr(self, "token_type_ids"): - buffered_token_type_ids = self.token_type_ids[:, :seq_length] - buffered_token_type_ids_expanded = buffered_token_type_ids.expand( - input_shape[0], seq_length - ) - token_type_ids = buffered_token_type_ids_expanded - else: - token_type_ids = torch.zeros( - input_shape, dtype=torch.long, device=self.position_ids.device - ) - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = inputs_embeds + token_type_embeddings - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings - - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - -class TextTransformer(nn.Module): - # TODO(asg): Add support for pretrained checkpoint loading - def __init__( - self, - embeddings: nn.Module, - encoder: nn.Module, - layernorm: nn.Module, - pooler: nn.Module, - weight_init_fn: Optional[Callable] = None, - initializer_range: float = 0.02, - pad_token_id: int = 0, - **kwargs: Any, - ): - super().__init__() - - self.embeddings = embeddings - self.encoder = encoder - self.layernorm = layernorm - self.pooler = pooler - self.pad_token_id = pad_token_id - - if weight_init_fn is None: - weight_init_fn = partial( - init_transformer_weights, initializer_range=initializer_range - ) - - self.apply(weight_init_fn) - - def get_extended_attention_mask( - self, attention_mask: Tensor, input_shape: Tuple[int], device: device - ) -> Tensor: - """ - Makes broadcastable attention and causal masks so that future and masked tokens are ignored. - Arguments: - attention_mask (`torch.Tensor`): - Mask with ones indicating tokens to attend to, zeros for tokens to ignore. - input_shape (`Tuple[int]`): - The shape of the input to the model. - device: (`torch.device`): - The device of the input to the model. - Returns: - `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. - """ - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- if attention_mask.dim() == 3: - extended_attention_mask = attention_mask[:, None, :, :] - elif attention_mask.dim() == 2: - # Provided a padding mask of dimensions [batch_size, seq_length] - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] - extended_attention_mask = attention_mask[:, None, None, :] - else: - raise ValueError( - f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" - ) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to( - dtype=attention_mask.dtype - ) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - return extended_attention_mask - - def forward( - self, - input_ids: Optional[Tensor] = None, - token_type_ids: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - attention_mask: Optional[Tensor] = None, - ): - if input_ids is None: - raise ValueError("You have to specify input_ids") - input_shape = input_ids.size() - device = input_ids.device - - if attention_mask is None: - attention_mask = torch.ones(input_shape, device=device) - attention_mask[input_ids == self.pad_token_id] = 0 - if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - - # We can provide a self-attention mask of dimensions - # [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( - attention_mask, input_shape, device - ) - - embedding_output = self.embeddings( - input_ids=input_ids, - token_type_ids=token_type_ids, - position_ids=position_ids, - ) - - encoder_outputs = self.encoder( - embedding_output, - attention_mask=extended_attention_mask, - ) - sequence_output = encoder_outputs[0] - sequence_output = self.layernorm(sequence_output) - pooled_output = ( - self.pooler(sequence_output) if self.pooler is not None else None - ) - - return FLAVATransformerOutput( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - class DalleConv2d(nn.Module): def __init__(self, n_in: int, n_out: int, kw: int): super().__init__() diff --git a/torchmultimodal/models/flava/flava_text_encoder.py b/torchmultimodal/models/flava/flava_text_encoder.py new file mode 100644 index 00000000..70eb3039 --- /dev/null +++ b/torchmultimodal/models/flava/flava_text_encoder.py @@ -0,0 +1,264 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from functools import partial +from typing import Any, Callable, Optional, Tuple + +import torch +from packaging import version +from torch import device, nn, Tensor +from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm +from torchmultimodal.modules.layers.transformer import ( + FLAVATransformerEncoder, + FLAVATransformerOutput, + init_transformer_weights, +) +from torchmultimodal.modules.losses.flava import Pooler + + +class TextEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings following BERT.""" + + def __init__( + self, + hidden_size: int = 768, + vocab_size: int = 30522, + pad_token_id: int = 0, + type_vocab_size: int = 2, + max_position_embeddings: int = 512, + layer_norm_eps: float = 1e-12, + hidden_dropout_prob: float = 0.1, + ): + super().__init__() + self.word_embeddings = nn.Embedding( + vocab_size, hidden_size, padding_idx=pad_token_id + ) + self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) + self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.dropout = nn.Dropout(hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(max_position_embeddings).expand((1, -1)) + ) + if version.parse(torch.__version__) > version.parse("1.6.0"): + self.register_buffer( + "token_type_ids", + torch.zeros(self.position_ids.size(), dtype=torch.long), + persistent=False, + ) + + def forward( + self, + input_ids: Optional[Tensor] = None, + token_type_ids: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + inputs_embeds: Optional[Tensor] = None, + past_key_values_length: int = 0, + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand( + input_shape[0], seq_length + ) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=self.position_ids.device + ) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class TextTransformer(nn.Module): + # TODO(asg): Add support for pretrained checkpoint loading + def __init__( + self, + embeddings: nn.Module, + encoder: nn.Module, + layernorm: nn.Module, + pooler: nn.Module, + weight_init_fn: Optional[Callable] = None, + 
initializer_range: float = 0.02, + pad_token_id: int = 0, + **kwargs: Any, + ): + super().__init__() + + self.embeddings = embeddings + self.encoder = encoder + self.layernorm = layernorm + self.pooler = pooler + self.pad_token_id = pad_token_id + + if weight_init_fn is None: + weight_init_fn = partial( + init_transformer_weights, initializer_range=initializer_range + ) + + self.apply(weight_init_fn) + + def get_extended_attention_mask( + self, attention_mask: Tensor, input_shape: Tuple[int], device: device + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + Arguments: + attention_mask (`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`Tuple[int]`): + The shape of the input to the model. + device: (`torch.device`): + The device of the input to the model. + Returns: + `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=attention_mask.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids: Optional[Tensor] = None, + token_type_ids: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + ): + if input_ids is None: + raise ValueError("You have to specify input_ids") + input_shape = input_ids.size() + device = input_ids.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + attention_mask[input_ids == self.pad_token_id] = 0 + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions + # [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device + ) + + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + return FLAVATransformerOutput( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +def flava_text_encoder( + hidden_size: int = 768, + num_attention_heads: int = 12, + num_hidden_layers: int = 12, + hidden_dropout_prob: float = 0.0, + intermediate_size: int = 3072, + intermediate_activation: Callable[..., Tensor] = nn.functional.gelu, + attention_probs_dropout_prob: float = 0.0, + layer_norm_eps: float = 1e-12, + vocab_size: int = 30522, + pad_token_id: int = 0, + type_vocab_size: int = 2, + max_position_embeddings: int = 512, +): + embeddings = TextEmbeddings( + hidden_size=hidden_size, + vocab_size=vocab_size, + pad_token_id=pad_token_id, + type_vocab_size=type_vocab_size, + max_position_embeddings=max_position_embeddings, + layer_norm_eps=layer_norm_eps, + hidden_dropout_prob=hidden_dropout_prob, + ) + + encoder = FLAVATransformerEncoder( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_hidden_layers=num_hidden_layers, + hidden_dropout_prob=hidden_dropout_prob, + intermediate_size=intermediate_size, + intermediate_activation=intermediate_activation, + attention_probs_dropout_prob=attention_probs_dropout_prob, + layer_norm_eps=layer_norm_eps, + pad_token_id=pad_token_id, + ) + + layernorm = Fp32LayerNorm(hidden_size, eps=layer_norm_eps) + pooler = Pooler(hidden_size=hidden_size) + + return TextTransformer( + embeddings=embeddings, + encoder=encoder, + layernorm=layernorm, + pooler=pooler, + ) From ca413041fd4ce679c9783ee902e1919717392060 Mon Sep 17 00:00:00 2001 From: ankitade Date: Thu, 23 Jun 2022 05:02:09 +0000 Subject: [PATCH 2/3] Update on "[FLAVA] Separate out text and image encoders" [ghstack-poisoned] --- mypy.ini | 2 +- torchmultimodal/models/flava/flava_image_encoder.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/mypy.ini b/mypy.ini index 149afdea..304c4392 100644 --- a/mypy.ini +++ b/mypy.ini @@ -14,7 +14,7 @@ namespace_packages = True install_types = True # TODO (T116951827): Remove after fixing FLAVA type check errors -exclude = models/flava/flava_model.py|modules/losses/flava.py +exclude = models/flava/flava_model.py| models/flava/flava_text_encoder.py|modules/losses/flava.py [mypy-PIL.*] ignore_missing_imports = True diff --git a/torchmultimodal/models/flava/flava_image_encoder.py b/torchmultimodal/models/flava/flava_image_encoder.py index eabad7f7..7e00dc39 100644 --- a/torchmultimodal/models/flava/flava_image_encoder.py +++ b/torchmultimodal/models/flava/flava_image_encoder.py @@ -5,6 +5,8 @@ # LICENSE file in the root directory of this source tree. 
import collections +import math +import warnings from functools import partial from typing import Any, Callable, Optional @@ -107,16 +109,16 @@ def interpolate_pos_encoding(self, embeddings, height, width): class_pos_embed = self.position_embeddings[:, 0] patch_pos_embed = self.position_embeddings[:, 1:] dim = embeddings.shape[-1] - h0 = height // self.config.patch_size - w0 = width // self.config.patch_size + h0 = height // self.patch_embeddings.patch_size + w0 = width // self.patch_embeddings.patch_size # we add a small number to avoid floating point error in the interpolation # see discussion at https://github.com/facebookresearch/dino/issues/8 h0, w0 = h0 + 0.1, w0 + 0.1 patch_pos_embed = nn.functional.interpolate( patch_pos_embed.reshape( - 1, int(math.sqrt(N)), int(math.sqrt(N)), dim + 1, int(math.sqrt(n)), int(math.sqrt(n)), dim ).permute(0, 3, 1, 2), - scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)), + scale_factor=(h0 / math.sqrt(n), w0 / math.sqrt(n)), mode="bicubic", align_corners=False, ) @@ -300,5 +302,4 @@ def forward( pooler_output=output.pooler_output, hidden_states=output.hidden_states, attentions=output.attentions, - image_labels=image_labels, ) From fbb769b5443ec2b525e2195fa5356adc29fe8aa5 Mon Sep 17 00:00:00 2001 From: ankitade Date: Thu, 23 Jun 2022 05:08:15 +0000 Subject: [PATCH 3/3] Update on "[FLAVA] Separate out text and image encoders" Separate out the encoders into their own module without ay logic changes (except fixing 2 minor bugs, see annotations by me) and add tests [ghstack-poisoned] --- mypy.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy.ini b/mypy.ini index 304c4392..dbcaf89b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -14,7 +14,7 @@ namespace_packages = True install_types = True # TODO (T116951827): Remove after fixing FLAVA type check errors -exclude = models/flava/flava_model.py| models/flava/flava_text_encoder.py|modules/losses/flava.py +exclude = models/flava/flava_model.py|models/flava/flava_text_encoder.py|modules/losses/flava.py [mypy-PIL.*] ignore_missing_imports = True
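Below is a minimal usage sketch of the two encoder builders that this stack separates out, assuming the default hyperparameters shown in the new modules (224x224 images with 16x16 patches for the image encoder, a BERT-style 30522-token vocabulary for the text encoder). It is illustrative only and not part of the committed diffs; the dummy inputs and the printed shapes are inferred from the defaults rather than verified against the test suite.

import torch
from torchmultimodal.models.flava.flava_image_encoder import flava_image_encoder
from torchmultimodal.models.flava.flava_text_encoder import flava_text_encoder

# Build the encoders with the defaults defined in flava_image_encoder.py and
# flava_text_encoder.py (hidden_size=768, 12 layers, 12 attention heads).
image_encoder = flava_image_encoder()
text_encoder = flava_text_encoder()

# Run a forward pass on dummy inputs. Both encoders return a
# FLAVATransformerOutput with last_hidden_state, pooler_output,
# hidden_states, and attentions fields.
image_out = image_encoder(torch.randn(2, 3, 224, 224))
text_out = text_encoder(torch.randint(0, 30522, (2, 16)))

print(image_out.last_hidden_state.shape)  # expected (2, 197, 768): 196 patches + [CLS]
print(text_out.last_hidden_state.shape)   # expected (2, 16, 768)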