Commit 0349375

[FLAVA] Separate out text and image encoders
ghstack-source-id: 5212d85
Pull Request resolved: #102
1 parent 7b6b26c commit 0349375
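
In effect, the image and text pathways can now be constructed and run on their own. Below is a minimal sketch of the separated image pathway, using only class names and constructor arguments that appear in the new test fixtures in this commit (the tiny sizes are test-only, not realistic model dimensions):

    import torch
    from torch import nn
    from torchmultimodal.models.flava.flava_image_encoder import (
        ImageEmbeddings,
        ImageTransformer,
    )
    from torchmultimodal.modules.layers.transformer import FLAVATransformerEncoder

    # Each modality now owns its embeddings; the transformer encoder stack
    # is a shared architecture plugged into either pathway.
    image_encoder = ImageTransformer(
        embeddings=ImageEmbeddings(image_size=2, patch_size=1, hidden_size=2),
        encoder=FLAVATransformerEncoder(
            hidden_size=2,
            num_attention_heads=1,
            num_hidden_layers=1,
            hidden_dropout_prob=0.0,
            intermediate_size=1,
            attention_probs_dropout_prob=0.0,
        ),
        layernorm=nn.LayerNorm(2),
        pooler=nn.Identity(),
    )

    out = image_encoder(torch.ones(2, 3, 2, 2))  # (batch, channels, height, width)
    # Transformer-style outputs, all asserted in the tests below:
    # out.last_hidden_state, out.pooler_output, out.hidden_states, out.attentions

The text pathway (TextEmbeddings + TextTransformer) mirrors this, taking input_ids and an optional attention_mask, as the second test file below shows.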

File tree

7 files changed: +840 additions, -529 deletions


test/models/flava/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
File renamed without changes.
Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest

import torch
from test.test_utils import assert_expected, set_rng_seed
from torch import nn
from torchmultimodal.models.flava.flava_image_encoder import (
    ImageEmbeddings,
    ImageTransformer,
)
from torchmultimodal.modules.layers.transformer import FLAVATransformerEncoder


class TestFlavaImageEncoder(unittest.TestCase):
    def setUp(self):
        set_rng_seed(0)
        torch.manual_seed(0)
        self.image_embedding = ImageEmbeddings(
            image_size=2, patch_size=1, hidden_size=2
        )

        encoder = FLAVATransformerEncoder(
            hidden_size=2,
            num_attention_heads=1,
            num_hidden_layers=1,
            hidden_dropout_prob=0.0,
            intermediate_size=1,
            attention_probs_dropout_prob=0.0,
        )
        self.image_encoder = ImageTransformer(
            embeddings=self.image_embedding,
            encoder=encoder,
            layernorm=nn.LayerNorm(2),
            pooler=nn.Identity(),
        )

    def test_embedding(self):
        input = torch.ones(2, 3, 2, 2)
        out = self.image_embedding(input)
        assert_expected(
            out,
            torch.Tensor(
                [
                    [
                        [0.0000, 0.0000],
                        [0.0224, 0.0573],
                        [0.0224, 0.0573],
                        [0.0224, 0.0573],
                        [0.0224, 0.0573],
                    ],
                    [
                        [0.0000, 0.0000],
                        [0.0224, 0.0573],
                        [0.0224, 0.0573],
                        [0.0224, 0.0573],
                        [0.0224, 0.0573],
                    ],
                ]
            ),
            atol=1e-4,
            rtol=0,
        )

    def test_image_encoder(self):
        input = torch.ones(2, 3, 2, 2)
        out = self.image_encoder(input)
        assert_expected(
            out.last_hidden_state,
            torch.Tensor(
                [
                    [
                        [-0.0040, 0.0040],
                        [-0.9840, 0.9840],
                        [-0.9840, 0.9840],
                        [-0.9840, 0.9840],
                        [-0.9840, 0.9840],
                    ],
                    [
                        [-0.0040, 0.0040],
                        [-0.9840, 0.9840],
                        [-0.9840, 0.9840],
                        [-0.9840, 0.9840],
                        [-0.9840, 0.9840],
                    ],
                ]
            ),
            atol=1e-4,
            rtol=0,
        )
        assert_expected(out.pooler_output, out.last_hidden_state)
        assert_expected(
            out.hidden_states,
            (
                torch.Tensor(
                    [
                        [
                            [0.0000, 0.0000],
                            [0.0224, 0.0573],
                            [0.0224, 0.0573],
                            [0.0224, 0.0573],
                            [0.0224, 0.0573],
                        ],
                        [
                            [0.0000, 0.0000],
                            [0.0224, 0.0573],
                            [0.0224, 0.0573],
                            [0.0224, 0.0573],
                            [0.0224, 0.0573],
                        ],
                    ]
                ),
                torch.Tensor(
                    [
                        [
                            [0.0008, 0.0008],
                            [0.0232, 0.0581],
                            [0.0232, 0.0581],
                            [0.0232, 0.0581],
                            [0.0232, 0.0581],
                        ],
                        [
                            [0.0008, 0.0008],
                            [0.0232, 0.0581],
                            [0.0232, 0.0581],
                            [0.0232, 0.0581],
                            [0.0232, 0.0581],
                        ],
                    ]
                ),
            ),
            atol=1e-4,
            rtol=0,
        )
        assert_expected(
            out.attentions,
            (
                torch.Tensor(
                    [
                        [
                            [
                                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                                [0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
                                [0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
                                [0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
                                [0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
                            ]
                        ],
                        [
                            [
                                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                                [0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
                                [0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
                                [0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
                                [0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
                            ]
                        ],
                    ]
                ),
            ),
            atol=1e-4,
            rtol=0,
        )
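
A quick sanity check on the expected shapes above (an annotation, not part of the diff): with image_size=2 and patch_size=1, the embedding layer cuts each image into (2 // 1) ** 2 = 4 patch tokens and prepends one class-token-style row (the distinct first row in the expected tensors), giving sequence length 5. That yields the (2, 5, 2) hidden states and the (2, 1, 5, 5) attention maps asserted above.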
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest

import torch
from test.test_utils import assert_expected, set_rng_seed
from torch import nn
from torchmultimodal.models.flava.flava_text_encoder import (
    TextEmbeddings,
    TextTransformer,
)
from torchmultimodal.modules.layers.transformer import FLAVATransformerEncoder


class TestFlavaTextEncoder(unittest.TestCase):
    def setUp(self):
        set_rng_seed(0)
        self.text_embedding = TextEmbeddings(
            hidden_size=2,
            vocab_size=3,
            max_position_embeddings=2,
            hidden_dropout_prob=0,
        )
        emb_weights = torch.Tensor([[0, 1], [1, 0], [1, 1]])
        self.text_embedding.word_embeddings = nn.Embedding.from_pretrained(emb_weights)
        self.text_embedding.position_embeddings = nn.Embedding.from_pretrained(
            emb_weights
        )
        self.text_embedding.token_type_embeddings = nn.Embedding.from_pretrained(
            emb_weights
        )

        encoder = FLAVATransformerEncoder(
            hidden_size=2,
            num_attention_heads=1,
            num_hidden_layers=1,
            hidden_dropout_prob=0.0,
            intermediate_size=1,
            attention_probs_dropout_prob=0.0,
        )
        self.text_encoder = TextTransformer(
            embeddings=self.text_embedding,
            encoder=encoder,
            layernorm=nn.LayerNorm(2),
            pooler=nn.Identity(),
        )

    def test_embedding(self):
        input_ids = torch.IntTensor([[0, 1]])
        out = self.text_embedding(input_ids)
        expected = torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]])
        assert_expected(out, expected)

    def test_text_transformer(self):
        out = self.text_encoder(torch.IntTensor([[0, 1]]))

        assert_expected(
            out.last_hidden_state, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]])
        )

        assert_expected(
            out.hidden_states,
            (
                torch.Tensor([[[1.0000, -1.0000], [-1.0000, 1.0000]]]),
                torch.Tensor([[[1.0008, -0.9994], [-0.9997, 1.0012]]]),
            ),
            atol=1e-4,
            rtol=0.0,
        )

        assert_expected(out.attentions, (torch.Tensor([[[[0, 1.0], [0.0, 1.0]]]]),))

    def test_text_transformer_attn_mask(self):
        input_ids = torch.IntTensor([[0, 1]])
        attn_mask = torch.IntTensor([[1, 0]])
        out = self.text_encoder(input_ids, attention_mask=attn_mask)

        assert_expected(
            out.last_hidden_state, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]])
        )

        assert_expected(
            out.hidden_states,
            (
                torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]]),
                torch.Tensor([[[0.9997, -1.0012], [-1.0008, 0.9994]]]),
            ),
            atol=1e-4,
            rtol=0.0,
        )

        assert_expected(out.pooler_output, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]]))
        assert_expected(out.attentions, (torch.Tensor([[[[1.0, 0], [1.0, 0]]]]),))
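
Why the masked test expects [[1.0, 0], [1.0, 0]] (again an annotation, not part of the diff): attention_mask=[[1, 0]] marks the second token as padding, so its key is excluded from the softmax and every query's attention collapses onto position 0; in the unmasked test, by contrast, both queries attend to position 1. A generic sketch of that mechanic, with made-up scores rather than the model's actual values:

    import torch

    scores = torch.zeros(2, 2)      # hypothetical pre-softmax attention scores
    mask = torch.tensor([[1, 0]])   # 1 = attend, 0 = padding
    scores = scores.masked_fill(mask == 0, float("-inf"))
    print(scores.softmax(dim=-1))   # tensor([[1., 0.], [1., 0.]])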
