Commit 7223af4

[FLAVA] Change ordering on contrastive loss initialization

ghstack-source-id: 39675e4
Pull Request resolved: #105

1 parent 0349375

3 files changed: 28 additions, 27 deletions

test/models/flava/test_flava.py
Lines changed: 10 additions & 9 deletions
@@ -27,27 +27,26 @@ def setUp(self):

     @torch.no_grad()
     def test_forward_classification(self):
-        flava = flava_model_for_classification(NUM_CLASSES)
         text = torch.randint(0, 30500, (2, 77), dtype=torch.long)
         image = torch.rand((2, 3, 224, 224))
-
         labels = torch.randint(0, 2, (2,), dtype=torch.long)
+        flava = flava_model_for_classification(NUM_CLASSES)
+        flava.eval()

         # Test multimodal scenario
         output = flava(image, text, "mm", labels)
-        self.assertAlmostEqual(output.loss.item(), 0.9724, places=4)
+        self.assertAlmostEqual(output.loss.item(), 0.7180, places=4)

         # Test unimodal image scenario
         output = flava(image, text, "image", labels)
-        self.assertAlmostEqual(output.loss.item(), 0.5453, places=4)
+        self.assertAlmostEqual(output.loss.item(), 0.7020, places=4)

         # Test unimodal text scenario
         output = flava(image, text, "text", labels)
-        self.assertAlmostEqual(output.loss.item(), 0.7074, places=4)
+        self.assertAlmostEqual(output.loss.item(), 0.6663, places=4)

     @torch.no_grad()
     def test_forward_pretraining(self):
-        flava = flava_model_for_pretraining()
         text = torch.randint(0, 30500, (2, 77), dtype=torch.long)
         image = torch.rand((2, 3, 224, 224))
         image_for_codebook = torch.rand(2, 3, 112, 112)

@@ -59,6 +58,8 @@ def test_forward_pretraining(self):
         mlm_labels[:, 1:3] = text[:, 1:3]
         itm_labels = torch.tensor((0, 1), dtype=torch.long)

+        flava = flava_model_for_pretraining()
+
         output = flava(
             image=image,
             text=text,

@@ -79,7 +80,7 @@ def test_forward_pretraining(self):
             sum(
                 value if value is not None else 0 for value in output.losses.values()
             ).item(),
-            20.4199,
+            21.4791,
             places=4,
         )

@@ -103,7 +104,7 @@ def test_forward_pretraining(self):
             sum(
                 value if value is not None else 0 for value in output.losses.values()
             ).item(),
-            9.3403,
+            8.9674,
             places=4,
         )

@@ -128,7 +129,7 @@ def test_forward_pretraining(self):
             sum(
                 value if value is not None else 0 for value in output.losses.values()
             ).item(),
-            10.8777,
+            10.0305,
             places=4,
         )
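Note: every hard-coded expected loss changes because the factories now consume the global RNG in a different order. Assuming the suite seeds the RNG in `setUp` (the hunk context suggests it does), any reordering of random-weight initialization versus random-input creation shifts every tensor drawn afterward. A minimal standalone sketch of the effect, not from the repo:

```python
import torch
from torch import nn

# Same seed, inputs drawn before the layer is initialized.
torch.manual_seed(0)
inputs_first = torch.rand(2, 4)
layer_after = nn.Linear(4, 1)

# Same seed, layer initialized before the inputs are drawn.
torch.manual_seed(0)
layer_first = nn.Linear(4, 1)
inputs_after = torch.rand(2, 4)

# Reordering shifts the RNG stream, so both weights and inputs differ,
# which is why the tests' expected loss values had to be updated.
print(torch.equal(inputs_first, inputs_after))              # False
print(torch.equal(layer_after.weight, layer_first.weight))  # False
```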

torchmultimodal/models/flava/flava_model.py
Lines changed: 3 additions & 3 deletions
@@ -185,9 +185,8 @@ def flava_model_for_pretraining(
     # TODO: Add parameters for loss here
 ):
     model = flava_model(**flava_model_kwargs)
-
-    codebook = DalleVAEEncoder(image_size=codebook_image_size)
     losses = FLAVAPretrainingLoss()
+    codebook = DalleVAEEncoder(image_size=codebook_image_size)

     flava = FLAVAForPreTraining(
         model=model,

@@ -211,7 +210,7 @@ def flava_model_for_classification(
     loss_fn: Optional[Callable[..., Tensor]] = None,
     **flava_model_kwargs: Any,
 ):
-    model = flava_model(**flava_model_kwargs)
+
     classifier = MLP(
         in_dim=classifier_in_dim,
         out_dim=num_classes,

@@ -220,6 +219,7 @@ def flava_model_for_classification(
         activation=classifier_activation,
         normalization=classifier_normalization,
     )
+    model = flava_model(**flava_model_kwargs)

     if loss_fn is None:
         loss_fn = nn.CrossEntropyLoss()
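Note: after this reorder, the loss module (for pretraining) and the MLP classifier (for classification) draw their initial parameters from the RNG stream before `flava_model` does. If construction-order-independent initialization were ever wanted, one option, shown as a sketch with hypothetical names and not something this commit does, is to build each component under a forked, component-specific RNG:

```python
import torch
from torch import nn

def build_seeded(seed: int, factory):
    # Construct a module under a forked RNG so its initial weights do not
    # depend on how many random numbers were drawn before this call.
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        return factory()

# Hypothetical usage: each component gets its own seed, so reordering
# these two lines would no longer change either module's parameters.
classifier = build_seeded(1, lambda: nn.Linear(768, 2))
projection = build_seeded(2, lambda: nn.Linear(768, 768))
```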

torchmultimodal/modules/losses/flava.py
Lines changed: 15 additions & 15 deletions
@@ -380,6 +380,21 @@ def forward(
         outputs = FLAVAPretrainingLossOutput()
         pos_mask = None

+        if (
+            image_sequence is not None
+            and text_sequence is not None
+            and self.contrastive_loss_weight > 0
+        ):
+            outputs.global_contrastive_output = self.contrastive_loss(
+                image_sequence,
+                text_sequence,
+                pos_mask,
+            )
+            outputs.global_contrastive_output.loss *= self.contrastive_loss_weight
+            outputs.losses.global_contrastive_loss = (
+                outputs.global_contrastive_output.loss
+            )
+
         # Check multimodal_masked_sequence to make sure this is unimodal case
         # This specific case can though be backpropagated directly as MIM is independent of
         # text, but that is a research question :)

@@ -461,19 +476,4 @@ def forward(
         outputs.mmm_image_output.loss *= self.mmm_image_loss_weight
         outputs.losses.mmm_image_loss = outputs.mmm_image_output.loss

-        if (
-            image_sequence is not None
-            and text_sequence is not None
-            and self.contrastive_loss_weight > 0
-        ):
-            outputs.global_contrastive_output = self.contrastive_loss(
-                image_sequence,
-                text_sequence,
-                pos_mask,
-            )
-            outputs.global_contrastive_output.loss *= self.contrastive_loss_weight
-            outputs.losses.global_contrastive_loss = (
-                outputs.global_contrastive_output.loss
-            )
-
         return outputs
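Note: in the loss module the change is a pure move: the global contrastive block now runs at the top of `forward`, before the masked and unimodal branches, instead of at the end. A toy, runnable analogue of the new ordering, with hypothetical names rather than the repo's actual API:

```python
import torch
import torch.nn.functional as F
from torch import nn

class ToyPretrainingLoss(nn.Module):
    """Toy analogue of a multi-task loss whose contrastive term runs first."""

    def __init__(self, contrastive_weight: float = 1.0):
        super().__init__()
        self.contrastive_weight = contrastive_weight

    def forward(self, image_emb=None, text_emb=None, itm_logits=None, itm_labels=None):
        losses = {}

        # Contrastive term computed first: only when both modalities are
        # present and the weight is positive, mirroring the moved block above.
        if image_emb is not None and text_emb is not None and self.contrastive_weight > 0:
            sim = image_emb @ text_emb.t()        # (batch, batch) similarity matrix
            targets = torch.arange(sim.size(0))   # matched pairs lie on the diagonal
            losses["global_contrastive"] = self.contrastive_weight * 0.5 * (
                F.cross_entropy(sim, targets) + F.cross_entropy(sim.t(), targets)
            )

        # Remaining task losses (e.g. image-text matching) follow.
        if itm_logits is not None and itm_labels is not None:
            losses["itm"] = F.cross_entropy(itm_logits, itm_labels)

        return losses

# Usage: both modalities present, so the contrastive term is populated.
loss_mod = ToyPretrainingLoss()
out = loss_mod(image_emb=torch.randn(2, 8), text_emb=torch.randn(2, 8))
print(out["global_contrastive"].item())
```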
