Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace_packages = True
install_types = True

# TODO (T116951827): Remove after fixing FLAVA type check errors
exclude = models/flava/flava_model.py|modules/losses/flava.py
exclude = models/flava/flava_model.py|models/flava/flava_text_encoder.py|modules/losses/flava.py

[mypy-PIL.*]
ignore_missing_imports = True
Expand Down
5 changes: 5 additions & 0 deletions test/models/flava/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
File renamed without changes.
167 changes: 167 additions & 0 deletions test/models/flava/test_flava_image_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from test.test_utils import assert_expected, set_rng_seed
from torch import nn
from torchmultimodal.models.flava.flava_image_encoder import (
ImageEmbeddings,
ImageTransformer,
)
from torchmultimodal.modules.layers.transformer import FLAVATransformerEncoder


class TestFlavaImageEncoder(unittest.TestCase):
def setUp(self):
set_rng_seed(0)
torch.manual_seed(0)
Comment on lines +21 to +22
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this line is redundant

self.image_embedding = ImageEmbeddings(
image_size=2, patch_size=1, hidden_size=2
)

encoder = FLAVATransformerEncoder(
hidden_size=2,
num_attention_heads=1,
num_hidden_layers=1,
hidden_dropout_prob=0.0,
intermediate_size=1,
attention_probs_dropout_prob=0.0,
)
self.image_encoder = ImageTransformer(
embeddings=self.image_embedding,
encoder=encoder,
layernorm=nn.LayerNorm(2),
pooler=nn.Identity(),
)

def test_embedding(self):
input = torch.ones(2, 3, 2, 2)
out = self.image_embedding(input)
assert_expected(
out,
torch.Tensor(
[
[
[0.0000, 0.0000],
[0.0224, 0.0573],
[0.0224, 0.0573],
[0.0224, 0.0573],
[0.0224, 0.0573],
],
[
[0.0000, 0.0000],
[0.0224, 0.0573],
[0.0224, 0.0573],
[0.0224, 0.0573],
[0.0224, 0.0573],
],
]
),
atol=1e-4,
rtol=0,
)

def test_image_encoder(self):
input = torch.ones(2, 3, 2, 2)
out = self.image_encoder(input)
assert_expected(
out.last_hidden_state,
torch.Tensor(
[
[
[-0.0040, 0.0040],
[-0.9840, 0.9840],
[-0.9840, 0.9840],
[-0.9840, 0.9840],
[-0.9840, 0.9840],
],
[
[-0.0040, 0.0040],
[-0.9840, 0.9840],
[-0.9840, 0.9840],
[-0.9840, 0.9840],
[-0.9840, 0.9840],
],
]
),
atol=1e-4,
rtol=0,
)
assert_expected(out.pooler_output, out.last_hidden_state)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe a bit confusing to do the transitive thing here. Can you just set the expected result to a var and compare both last_hidden_state and pooler_output to that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isnt the transitive thing actually making it clear which values should line up

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, just personal preference I guess

assert_expected(
out.hidden_states,
(
torch.Tensor(
[
[
[0.0000, 0.0000],
[0.0224, 0.0573],
[0.0224, 0.0573],
[0.0224, 0.0573],
[0.0224, 0.0573],
],
[
[0.0000, 0.0000],
[0.0224, 0.0573],
[0.0224, 0.0573],
[0.0224, 0.0573],
[0.0224, 0.0573],
],
]
),
torch.Tensor(
[
[
[0.0008, 0.0008],
[0.0232, 0.0581],
[0.0232, 0.0581],
[0.0232, 0.0581],
[0.0232, 0.0581],
],
[
[0.0008, 0.0008],
[0.0232, 0.0581],
[0.0232, 0.0581],
[0.0232, 0.0581],
[0.0232, 0.0581],
],
]
),
),
atol=1e-4,
rtol=0,
)
assert_expected(
out.attentions,
(
torch.Tensor(
[
[
[
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0.1999 due to rounding error? Maybe just make them all 0.2 for readability?

[0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
[0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
[0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
]
],
[
[
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
[0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
[0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
[0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
]
],
]
),
),
atol=1e-4,
rtol=0,
)
97 changes: 97 additions & 0 deletions test/models/flava/test_flava_text_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from test.test_utils import assert_expected, set_rng_seed
from torch import nn
from torchmultimodal.models.flava.flava_text_encoder import (
TextEmbeddings,
TextTransformer,
)
from torchmultimodal.modules.layers.transformer import FLAVATransformerEncoder


class TestFlavaTextEncoder(unittest.TestCase):
def setUp(self):
set_rng_seed(0)
self.text_embedding = TextEmbeddings(
hidden_size=2,
vocab_size=3,
max_position_embeddings=2,
hidden_dropout_prob=0,
)
emb_weights = torch.Tensor([[0, 1], [1, 0], [1, 1]])
self.text_embedding.word_embeddings = nn.Embedding.from_pretrained(emb_weights)
self.text_embedding.position_embeddings = nn.Embedding.from_pretrained(
emb_weights
)
self.text_embedding.token_type_embeddings = nn.Embedding.from_pretrained(
emb_weights
)

encoder = FLAVATransformerEncoder(
hidden_size=2,
num_attention_heads=1,
num_hidden_layers=1,
hidden_dropout_prob=0.0,
intermediate_size=1,
attention_probs_dropout_prob=0.0,
)
self.text_encoder = TextTransformer(
embeddings=self.text_embedding,
encoder=encoder,
layernorm=nn.LayerNorm(2),
pooler=nn.Identity(),
)

def test_embedding(self):
input_ids = torch.IntTensor([[0, 1]])
out = self.text_embedding(input_ids)
expected = torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]])
assert_expected(out, expected)

def test_text_transformer(self):
out = self.text_encoder(torch.IntTensor([[0, 1]]))

assert_expected(
out.last_hidden_state, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]])
)

assert_expected(
out.hidden_states,
(
torch.Tensor([[[1.0000, -1.0000], [-1.0000, 1.0000]]]),
torch.Tensor([[[1.0008, -0.9994], [-0.9997, 1.0012]]]),
),
atol=1e-4,
rtol=0.0,
)

assert_expected(out.attentions, (torch.Tensor([[[[0, 1.0], [0.0, 1.0]]]]),))

def test_text_transformer_attn_mask(self):
input_ids = torch.IntTensor([[0, 1]])
attn_mask = torch.IntTensor([[1, 0]])
out = self.text_encoder(input_ids, attention_mask=attn_mask)

assert_expected(
out.last_hidden_state, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]])
)

assert_expected(
out.hidden_states,
(
torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]]),
torch.Tensor([[[0.9997, -1.0012], [-1.0008, 0.9994]]]),
),
atol=1e-4,
rtol=0.0,
)

assert_expected(out.pooler_output, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]]))
assert_expected(out.attentions, (torch.Tensor([[[[1.0, 0], [1.0, 0]]]]),))
Loading