[FLAVA] Separate out text and image encoders #102
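For orientation, the tests in this diff build the newly separated encoders by composing an embeddings module, a FLAVATransformerEncoder, a layernorm, and a pooler. The sketch below mirrors the image-encoder setup from the test file added here and runs a single forward pass; it is an illustrative reading of the diff, not code from the PR itself.

```python
import torch
from torch import nn
from torchmultimodal.models.flava.flava_image_encoder import (
    ImageEmbeddings,
    ImageTransformer,
)
from torchmultimodal.modules.layers.transformer import FLAVATransformerEncoder

# Tiny configuration taken verbatim from the test below.
embeddings = ImageEmbeddings(image_size=2, patch_size=1, hidden_size=2)
encoder = FLAVATransformerEncoder(
    hidden_size=2,
    num_attention_heads=1,
    num_hidden_layers=1,
    hidden_dropout_prob=0.0,
    intermediate_size=1,
    attention_probs_dropout_prob=0.0,
)
image_encoder = ImageTransformer(
    embeddings=embeddings,
    encoder=encoder,
    layernorm=nn.LayerNorm(2),
    pooler=nn.Identity(),
)

# Forward pass on a batch of two 3-channel 2x2 images; per the test's expected
# tensors, last_hidden_state has shape (2, 5, 2), and the output also exposes
# pooler_output, hidden_states, and attentions.
out = image_encoder(torch.ones(2, 3, 2, 2))
```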
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,167 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from test.test_utils import assert_expected, set_rng_seed
from torch import nn
from torchmultimodal.models.flava.flava_image_encoder import (
    ImageEmbeddings,
    ImageTransformer,
)
from torchmultimodal.modules.layers.transformer import FLAVATransformerEncoder


class TestFlavaImageEncoder(unittest.TestCase):
    def setUp(self):
        set_rng_seed(0)
        torch.manual_seed(0)
        self.image_embedding = ImageEmbeddings(
            image_size=2, patch_size=1, hidden_size=2
        )

        encoder = FLAVATransformerEncoder(
            hidden_size=2,
            num_attention_heads=1,
            num_hidden_layers=1,
            hidden_dropout_prob=0.0,
            intermediate_size=1,
            attention_probs_dropout_prob=0.0,
        )
        self.image_encoder = ImageTransformer(
            embeddings=self.image_embedding,
            encoder=encoder,
            layernorm=nn.LayerNorm(2),
            pooler=nn.Identity(),
        )

    def test_embedding(self):
        input = torch.ones(2, 3, 2, 2)
        out = self.image_embedding(input)
        assert_expected(
            out,
            torch.Tensor(
                [
                    [
                        [0.0000, 0.0000],
                        [0.0224, 0.0573],
                        [0.0224, 0.0573],
                        [0.0224, 0.0573],
                        [0.0224, 0.0573],
                    ],
                    [
                        [0.0000, 0.0000],
                        [0.0224, 0.0573],
                        [0.0224, 0.0573],
                        [0.0224, 0.0573],
                        [0.0224, 0.0573],
                    ],
                ]
            ),
            atol=1e-4,
            rtol=0,
        )

    def test_image_encoder(self):
        input = torch.ones(2, 3, 2, 2)
        out = self.image_encoder(input)
        assert_expected(
            out.last_hidden_state,
            torch.Tensor(
                [
                    [
                        [-0.0040, 0.0040],
                        [-0.9840, 0.9840],
                        [-0.9840, 0.9840],
                        [-0.9840, 0.9840],
                        [-0.9840, 0.9840],
                    ],
                    [
                        [-0.0040, 0.0040],
                        [-0.9840, 0.9840],
                        [-0.9840, 0.9840],
                        [-0.9840, 0.9840],
                        [-0.9840, 0.9840],
                    ],
                ]
            ),
            atol=1e-4,
            rtol=0,
        )
        assert_expected(out.pooler_output, out.last_hidden_state)
Review comment: nit: maybe a bit confusing to do the transitive thing here. Can you just set the expected result to a var and compare both last_hidden_state and pooler_output to that?
Reply: isn't the transitive thing actually making it clear which values should line up?
Reply: Sure, just personal preference I guess.
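The first comment's suggestion, as a minimal sketch inside test_image_encoder (names and values mirror the assertions above; this is illustrative, not a change made in this PR):

```python
# Compare both outputs against one shared expected tensor instead of asserting
# pooler_output transitively through last_hidden_state.
expected = torch.Tensor(
    [
        [
            [-0.0040, 0.0040],
            [-0.9840, 0.9840],
            [-0.9840, 0.9840],
            [-0.9840, 0.9840],
            [-0.9840, 0.9840],
        ],
        [
            [-0.0040, 0.0040],
            [-0.9840, 0.9840],
            [-0.9840, 0.9840],
            [-0.9840, 0.9840],
            [-0.9840, 0.9840],
        ],
    ]
)
assert_expected(out.last_hidden_state, expected, atol=1e-4, rtol=0)
assert_expected(out.pooler_output, expected, atol=1e-4, rtol=0)
```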
        assert_expected(
            out.hidden_states,
            (
                torch.Tensor(
                    [
                        [
                            [0.0000, 0.0000],
                            [0.0224, 0.0573],
                            [0.0224, 0.0573],
                            [0.0224, 0.0573],
                            [0.0224, 0.0573],
                        ],
                        [
                            [0.0000, 0.0000],
                            [0.0224, 0.0573],
                            [0.0224, 0.0573],
                            [0.0224, 0.0573],
                            [0.0224, 0.0573],
                        ],
                    ]
                ),
                torch.Tensor(
                    [
                        [
                            [0.0008, 0.0008],
                            [0.0232, 0.0581],
                            [0.0232, 0.0581],
                            [0.0232, 0.0581],
                            [0.0232, 0.0581],
                        ],
                        [
                            [0.0008, 0.0008],
                            [0.0232, 0.0581],
                            [0.0232, 0.0581],
                            [0.0232, 0.0581],
                            [0.0232, 0.0581],
                        ],
                    ]
                ),
            ),
            atol=1e-4,
            rtol=0,
        )
        assert_expected(
            out.attentions,
            (
                torch.Tensor(
                    [
                        [
                            [
                                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                                [0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
Review comment: 0.1999 due to rounding error? Maybe just make them all 0.2 for readability?
A sketch of that suggestion appears after this file's diff.
                                [0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
                                [0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
                                [0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
                            ]
                        ],
                        [
                            [
                                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                                [0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
                                [0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
                                [0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
                                [0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
                            ]
                        ],
                    ]
                ),
            ),
            atol=1e-4,
            rtol=0,
        )
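One way to take the readability suggestion from the comment above about the 0.1999 entries, sketched under assumptions: the uniform-0.2 expectation and the slightly widened atol are illustrative only, since 0.1999 vs. 0.2000 sits right at the original 1e-4 tolerance and has not been re-verified against the actual attention output.

```python
# Drop-in for the out.attentions check in test_image_encoder (assumed context).
# A uniform 0.2 expectation built with torch.full keeps the literal compact;
# the tolerance is loosened to absorb the ~1e-4 rounding gap seen above.
expected_attn = torch.full((2, 1, 5, 5), 0.2)
assert_expected(out.attentions, (expected_attn,), atol=2e-4, rtol=0)
```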
@@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from test.test_utils import assert_expected, set_rng_seed
from torch import nn
from torchmultimodal.models.flava.flava_text_encoder import (
    TextEmbeddings,
    TextTransformer,
)
from torchmultimodal.modules.layers.transformer import FLAVATransformerEncoder


class TestFlavaTextEncoder(unittest.TestCase):
    def setUp(self):
        set_rng_seed(0)
        self.text_embedding = TextEmbeddings(
            hidden_size=2,
            vocab_size=3,
            max_position_embeddings=2,
            hidden_dropout_prob=0,
        )
        emb_weights = torch.Tensor([[0, 1], [1, 0], [1, 1]])
        self.text_embedding.word_embeddings = nn.Embedding.from_pretrained(emb_weights)
        self.text_embedding.position_embeddings = nn.Embedding.from_pretrained(
            emb_weights
        )
        self.text_embedding.token_type_embeddings = nn.Embedding.from_pretrained(
            emb_weights
        )

        encoder = FLAVATransformerEncoder(
            hidden_size=2,
            num_attention_heads=1,
            num_hidden_layers=1,
            hidden_dropout_prob=0.0,
            intermediate_size=1,
            attention_probs_dropout_prob=0.0,
        )
        self.text_encoder = TextTransformer(
            embeddings=self.text_embedding,
            encoder=encoder,
            layernorm=nn.LayerNorm(2),
            pooler=nn.Identity(),
        )

    def test_embedding(self):
        input_ids = torch.IntTensor([[0, 1]])
        out = self.text_embedding(input_ids)
        expected = torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]])
        assert_expected(out, expected)

    def test_text_transformer(self):
        out = self.text_encoder(torch.IntTensor([[0, 1]]))

        assert_expected(
            out.last_hidden_state, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]])
        )

        assert_expected(
            out.hidden_states,
            (
                torch.Tensor([[[1.0000, -1.0000], [-1.0000, 1.0000]]]),
                torch.Tensor([[[1.0008, -0.9994], [-0.9997, 1.0012]]]),
            ),
            atol=1e-4,
            rtol=0.0,
        )

        assert_expected(out.attentions, (torch.Tensor([[[[0, 1.0], [0.0, 1.0]]]]),))

    def test_text_transformer_attn_mask(self):
        input_ids = torch.IntTensor([[0, 1]])
        attn_mask = torch.IntTensor([[1, 0]])
        out = self.text_encoder(input_ids, attention_mask=attn_mask)

        assert_expected(
            out.last_hidden_state, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]])
        )

        assert_expected(
            out.hidden_states,
            (
                torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]]),
                torch.Tensor([[[0.9997, -1.0012], [-1.0008, 0.9994]]]),
            ),
            atol=1e-4,
            rtol=0.0,
        )

        assert_expected(out.pooler_output, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]]))
        assert_expected(out.attentions, (torch.Tensor([[[[1.0, 0], [1.0, 0]]]]),))
Review comment: I think this line is redundant.