
Commit 440029d

[FLAVA] Move masked prediction head to flava_for_pretraining
ghstack-source-id: 0ee0911
Pull Request resolved: #195
1 parent a586c88 commit 440029d

File tree

torchmultimodal/models/flava/flava_model.py
torchmultimodal/modules/losses/flava.py

2 files changed: +115 -84 lines changed

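For orientation, a minimal sketch (not part of the commit) of the builder this change touches: after this commit, flava_model_for_pretraining constructs the four MaskedPredictionHead modules and attaches them to FLAVAForPreTraining instead of leaving them inside the losses. The import path follows the file edited below; the attribute checks are illustrative only, and constructing the model may download pretrained codebook weights.

from torchmultimodal.models.flava.flava_model import flava_model_for_pretraining

# Builds the FLAVA model, the DALL-E codebook, the ITM head and, new in this
# commit, the four masked prediction heads.
flava = flava_model_for_pretraining()  # image_vocab_size defaults to 8192

# The heads are now plain submodules of FLAVAForPreTraining.
print(type(flava.mlm_head).__name__)      # MaskedPredictionHead
print(type(flava.mmm_mim_head).__name__)  # MaskedPredictionHead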

torchmultimodal/models/flava/flava_model.py

Lines changed: 91 additions & 2 deletions
@@ -60,7 +60,7 @@


 FLAVA_FOR_PRETRAINED_MAPPING = {
-    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_cl_itm.pt"
+    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_mp.pt"
 }

@@ -309,6 +309,45 @@ def forward(self, hidden_states: Tensor):
         return logits


+class MaskedPredictionHead(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int = 768,
+        vocab_size: int = 30522,
+        transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
+        layer_norm_eps: float = 1e-5,
+        use_fp32_layer_norm: bool = True,
+        **kwargs: Any,
+    ):
+        super().__init__()
+
+        self.dense = nn.Linear(hidden_size, hidden_size)
+        self.transform_act_fn = transform_act_fn
+
+        self.layer_norm: nn.LayerNorm
+        if use_fp32_layer_norm:
+            self.layer_norm = Fp32LayerNorm(hidden_size, eps=layer_norm_eps)
+        else:
+            self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+
+        # The output weights are the same as the input embeddings, but there is
+        # an output-only bias for each token.
+        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
+
+        self.bias = nn.Parameter(torch.zeros(vocab_size))
+
+        # Need a link between the two variables so that the bias is
+        # correctly resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.transform_act_fn(hidden_states)
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+
 class FLAVAForPreTraining(nn.Module, PretrainedMixin):
     # TODOs:
     # 1. Expose logit scale
@@ -320,12 +359,20 @@ def __init__(
         image_codebook: nn.Module,
         loss: nn.Module,
         itm_head: nn.Module,
+        mlm_head: nn.Module,
+        mim_head: nn.Module,
+        mmm_mlm_head: nn.Module,
+        mmm_mim_head: nn.Module,
     ):
         super().__init__()
         self.model = model
         self.image_codebook = image_codebook
         self.loss = loss
         self.itm_head = itm_head
+        self.mlm_head = mlm_head
+        self.mim_head = mim_head
+        self.mmm_mlm_head = mmm_mlm_head
+        self.mmm_mim_head = mmm_mim_head

     def encode_image(
         self,
@@ -378,6 +425,25 @@ def forward(
         if multimodal_masked_sequence is not None:
             itm_logits = self.itm_head(multimodal_masked_sequence)

+        image_masked_sequence = flava_output.image_masked.last_hidden_state
+        text_masked_sequence = flava_output.text_masked.last_hidden_state
+        mlm_head_output = (
+            mim_head_output
+        ) = mmm_mlm_head_output = mmm_mim_head_output = None
+
+        if image_masked_sequence is not None and multimodal_masked_sequence is None:
+            mim_head_output = self.mim_head(image_masked_sequence)
+        if text_masked_sequence is not None and multimodal_masked_sequence is None:
+            mlm_head_output = self.mlm_head(text_masked_sequence)
+        if multimodal_masked_sequence is not None:
+            start_index = -(text_masked_sequence.size(1))
+            mmm_text_sequence = multimodal_masked_sequence[:, start_index:, :]
+            mmm_mlm_head_output = self.mmm_mlm_head(mmm_text_sequence)
+        if multimodal_masked_sequence is not None:
+            total_indices = image_masked_sequence.size(1) - 1
+            mmm_image_sequence = multimodal_masked_sequence[:, 2 : 2 + total_indices, :]
+            mmm_mim_head_output = self.mmm_mim_head(mmm_image_sequence)
+
         return self.loss(
             image_sequence=flava_output.image.last_hidden_state,
             text_sequence=flava_output.text.last_hidden_state,
@@ -393,6 +459,10 @@ def forward(
             projected_image_embeddings=flava_output.projected_image_embeddings,
             projected_text_embeddings=flava_output.projected_text_embeddings,
             itm_logits=itm_logits,
+            mlm_head_output=mlm_head_output,
+            mim_head_output=mim_head_output,
+            mmm_mlm_head_output=mmm_mlm_head_output,
+            mmm_mim_head_output=mmm_mim_head_output,
         )

@@ -544,17 +614,36 @@ def flava_model(
 def flava_model_for_pretraining(
     codebook_image_size: int = 112,
     pretrained_model_key: Optional[str] = None,
+    image_vocab_size: int = 8192,
     **flava_model_kwargs: Any,
     # TODO: Add parameters for loss here
 ) -> FLAVAForPreTraining:
     model = flava_model(**flava_model_kwargs)
     hidden_size = flava_model_kwargs.get("hidden_size") or 768
+    text_vocab_size = flava_model_kwargs.get("vocab_size") or 30522
     itm_head = ITMHead(hidden_size)
+    mlm_head = MaskedPredictionHead(hidden_size=hidden_size, vocab_size=text_vocab_size)
+    mim_head = MaskedPredictionHead(
+        hidden_size=hidden_size, vocab_size=image_vocab_size
+    )
+    mmm_mlm_head = MaskedPredictionHead(
+        hidden_size=hidden_size, vocab_size=text_vocab_size
+    )
+    mmm_mim_head = MaskedPredictionHead(
+        hidden_size=hidden_size, vocab_size=image_vocab_size
+    )
     losses = FLAVAPretrainingLoss()
     codebook = DalleVAEEncoder(image_size=codebook_image_size)

     flava = FLAVAForPreTraining(
-        model=model, image_codebook=codebook, loss=losses, itm_head=itm_head
+        model=model,
+        image_codebook=codebook,
+        loss=losses,
+        itm_head=itm_head,
+        mlm_head=mlm_head,
+        mim_head=mim_head,
+        mmm_mlm_head=mmm_mlm_head,
+        mmm_mim_head=mmm_mim_head,
     )

     if pretrained_model_key is not None:
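
As a usage note for the class moved into this file, here is a minimal sketch (not part of the commit) that runs MaskedPredictionHead on a dummy tensor. The shapes are made up for illustration, and the import assumes the class is exposed from torchmultimodal.models.flava.flava_model as added above.

import torch
from torchmultimodal.models.flava.flava_model import MaskedPredictionHead

batch_size, seq_len, hidden_size, vocab_size = 2, 8, 768, 30522
head = MaskedPredictionHead(hidden_size=hidden_size, vocab_size=vocab_size)

# Stands in for e.g. flava_output.text_masked.last_hidden_state inside
# FLAVAForPreTraining.forward.
hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# dense -> activation -> layer norm -> decoder projection onto the vocabulary
logits = head(hidden_states)
print(logits.shape)  # torch.Size([2, 8, 30522])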

torchmultimodal/modules/losses/flava.py

Lines changed: 24 additions & 82 deletions
@@ -11,7 +11,6 @@

 import torch
 from torch import nn, Tensor
-from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm
 from torchmultimodal.modules.losses.contrastive_loss_with_temperature import (
     contrastive_loss_with_temperature,
     ContrastiveLossOutput,
@@ -130,45 +129,6 @@ def forward(
         return ITMLossOutput(logits=scores, loss=loss)


-class MaskedPredictionHead(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int = 768,
-        vocab_size: int = 30522,
-        transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
-        layer_norm_eps: float = 1e-5,
-        use_fp32_layer_norm: bool = True,
-        **kwargs: Any,
-    ):
-        super().__init__()
-
-        self.dense = nn.Linear(hidden_size, hidden_size)
-        self.transform_act_fn = transform_act_fn
-
-        self.layer_norm: nn.LayerNorm
-        if use_fp32_layer_norm:
-            self.layer_norm = Fp32LayerNorm(hidden_size, eps=layer_norm_eps)
-        else:
-            self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
-
-        # The output weights are the same as the input embeddings, but there is
-        # an output-only bias for each token.
-        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
-
-        self.bias = nn.Parameter(torch.zeros(vocab_size))
-
-        # Need a link between the two variables so that the bias is
-        # correctly resized with `resize_token_embeddings`
-        self.decoder.bias = self.bias
-
-    def forward(self, hidden_states: Tensor) -> Tensor:
-        hidden_states = self.dense(hidden_states)
-        hidden_states = self.transform_act_fn(hidden_states)
-        hidden_states = self.layer_norm(hidden_states)
-        hidden_states = self.decoder(hidden_states)
-        return hidden_states
-
-
 class MaskedPredictionLoss(nn.Module):
     def __init__(
         self,
@@ -181,36 +141,29 @@ def __init__(
         **kwargs: Any,
     ):
         super().__init__()
-
-        self.cls = MaskedPredictionHead(
-            hidden_size=hidden_size,
-            vocab_size=vocab_size,
-            transform_act_fn=transform_act_fn,
-            layer_norm_eps=layer_norm_eps,
-        )
         self.ignore_index = ignore_index
         self.vocab_size = vocab_size
         self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
         self.ignore_nan = ignore_nan

     def forward(
-        self, hidden_states: Tensor, masked_labels: Optional[Tensor] = None
+        self,
+        prediction: Tensor,
+        masked_labels: Optional[Tensor] = None,
+        pos_mask: Optional[Tensor] = None,
     ) -> MaskedPredictionLossOutput:
-        if self.training:
-            assert_labels_are_present(masked_labels, "masked labels")
-
-        if masked_labels is not None:
-            masked_tokens = masked_labels.ne(self.ignore_index)
-            masked_labels = masked_labels[masked_tokens]
-            sequence_output = hidden_states[masked_tokens, :]
-        else:
-            sequence_output = hidden_states

-        prediction = self.cls(sequence_output)
+        if pos_mask is not None:
+            masked_labels = masked_labels[pos_mask]
+        masked_tokens = masked_labels.ne(self.ignore_index)
+        masked_labels = masked_labels[masked_tokens]

         if masked_labels is None:
             masked_loss = prediction.sum() * 0
         else:
+            if pos_mask is not None:
+                prediction = prediction[pos_mask]
+            prediction = prediction[masked_tokens, :]
             masked_loss = self.ce_loss(
                 prediction.view(-1, self.vocab_size),
                 masked_labels.view(-1),
@@ -371,6 +324,10 @@ def forward(
         projected_image_embeddings: Optional[Tensor] = None,
         projected_text_embeddings: Optional[Tensor] = None,
         itm_logits: Optional[Tensor] = None,
+        mlm_head_output: Optional[Tensor] = None,
+        mim_head_output: Optional[Tensor] = None,
+        mmm_mlm_head_output: Optional[Tensor] = None,
+        mmm_mim_head_output: Optional[Tensor] = None,
     ) -> FLAVAPretrainingLossOutput:
         outputs = FLAVAPretrainingLossOutput()
         pos_mask = None
@@ -380,28 +337,28 @@
         # text, but that is a research question :)

         if (
-            image_masked_sequence is not None
+            mim_head_output is not None
             and self.mim_weight > 0
             and multimodal_masked_sequence is None
         ):
             # Remove CLS token from image_masked_sequence
-
             start_index = -mim_labels.size(1) if mim_labels is not None else 1
             outputs.mim_output = self.mim_loss(
-                image_masked_sequence[:, start_index:, :], mim_labels
+                mim_head_output[:, start_index:, :], mim_labels
             )
             outputs.mim_output.loss *= self.mim_weight
             outputs.losses.mim_loss = outputs.mim_output.loss

         # Check multimodal_masked_sequence to make sure this is unimodal case
+
         if (
-            text_masked_sequence is not None
+            mlm_head_output is not None
             and self.mlm_weight > 0
             and multimodal_masked_sequence is None
         ):
             start_index = -mlm_labels.size(1) if mlm_labels is not None else 1
             outputs.mlm_output = self.mlm_loss(
-                text_masked_sequence[:, start_index:, :], mlm_labels
+                mlm_head_output[:, start_index:, :], mlm_labels
             )
             outputs.mlm_output.loss *= self.mlm_weight
             outputs.losses.mlm_loss = outputs.mlm_output.loss
@@ -422,38 +379,23 @@
             outputs.itm_output.loss *= self.itm_loss_weight
             outputs.losses.itm_loss = outputs.itm_output.loss

-            multimodal_masked_sequence = multimodal_masked_sequence[pos_mask]
-            if mlm_labels is not None:
-                mlm_labels = mlm_labels[pos_mask]
-            if mim_labels is not None:
-                mim_labels = mim_labels[pos_mask]
-
-        if multimodal_masked_sequence is not None and self.mmm_text_loss_weight > 0:
-            start_index = (
-                -mlm_labels.size(1)
-                if mlm_labels is not None
-                else -(text_masked_sequence.size(1) - 1)
-            )
-            sequence_for_text = multimodal_masked_sequence[:, start_index:, :]
+        if mmm_mlm_head_output is not None and self.mmm_text_loss_weight > 0:
             outputs.mmm_text_output = self.mmm_loss.mlm(
-                sequence_for_text,
-                mlm_labels,
+                mmm_mlm_head_output, mlm_labels, pos_mask
             ) # type: ignore
             outputs.mmm_text_output.loss *= self.mmm_text_loss_weight
             outputs.losses.mmm_text_loss = outputs.mmm_text_output.loss

-        if multimodal_masked_sequence is not None and self.mmm_image_loss_weight > 0:
+        if mmm_mim_head_output is not None and self.mmm_image_loss_weight > 0:
             # Starts from 2 because of 2 CLS, one for multimodal encoder and one
             # that comes from image encoder.
             total_indices = (
                 mim_labels.size(1)
                 if mlm_labels is not None
                 else (image_masked_sequence.size(1) - 1)
             )
-            sequence_for_image = multimodal_masked_sequence[:, 2 : 2 + total_indices, :]
             outputs.mmm_image_output = self.mmm_loss.mim(
-                sequence_for_image,
-                mim_labels,
+                mmm_mim_head_output, mim_labels, pos_mask
             ) # type: ignore
             outputs.mmm_image_output.loss *= self.mmm_image_loss_weight
             outputs.losses.mmm_image_loss = outputs.mmm_image_output.loss
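
To illustrate the reworked MaskedPredictionLoss.forward contract, a small sketch (not from the commit): the head output is now computed upstream in FLAVAForPreTraining and passed in as prediction, masked_labels marks unmasked positions with the configured ignore_index, and pos_mask selects the ITM-positive samples before masked positions are gathered. The shapes and label values below are assumptions for the example.

import torch
from torchmultimodal.modules.losses.flava import MaskedPredictionLoss

batch_size, seq_len, vocab_size = 4, 16, 30522
loss_fn = MaskedPredictionLoss(vocab_size=vocab_size, ignore_index=-1)

# Output of a MaskedPredictionHead, computed in FLAVAForPreTraining.forward.
prediction = torch.randn(batch_size, seq_len, vocab_size)

# -1 (the ignore_index above) marks unmasked positions; masked positions
# carry the target token ids.
masked_labels = torch.full((batch_size, seq_len), -1, dtype=torch.long)
masked_labels[:, 3] = 42

# Keep only the ITM-positive samples in the batch, mirroring how
# FLAVAPretrainingLoss forwards pos_mask to the mmm losses.
pos_mask = torch.tensor([True, True, False, True])

output = loss_fn(prediction, masked_labels, pos_mask)
print(output.loss)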
