diff --git a/examples/flava/finetune.py b/examples/flava/finetune.py index 6e61ee98..6d371920 100644 --- a/examples/flava/finetune.py +++ b/examples/flava/finetune.py @@ -60,7 +60,6 @@ def main(): callbacks=[ LearningRateMonitor(logging_interval="step"), ], - strategy="ddp", ) trainer.fit(model, datamodule=datamodule) trainer.validate(model, datamodule=datamodule) diff --git a/examples/flava/model.py b/examples/flava/model.py index c0ffe717..29baad71 100644 --- a/examples/flava/model.py +++ b/examples/flava/model.py @@ -8,7 +8,7 @@ import torch from pytorch_lightning import LightningModule -from torchmultimodal.models.flava import ( +from torchmultimodal.models.flava.flava_model import ( flava_model_for_classification, flava_model_for_pretraining, ) diff --git a/examples/flava/tools/convert_weights.py b/examples/flava/tools/convert_weights.py index 36f3774c..cde95c8a 100644 --- a/examples/flava/tools/convert_weights.py +++ b/examples/flava/tools/convert_weights.py @@ -7,7 +7,7 @@ import argparse import torch -from torchmultimodal.models.flava import flava_model_for_pretraining +from torchmultimodal.models.flava.flava_model import flava_model_for_pretraining KEY_REPLACEMENTS = { "image_encoder.module": "image_encoder", diff --git a/mypy.ini b/mypy.ini index 5a98cca5..149afdea 100644 --- a/mypy.ini +++ b/mypy.ini @@ -14,7 +14,7 @@ namespace_packages = True install_types = True # TODO (T116951827): Remove after fixing FLAVA type check errors -exclude = models/flava.py|modules/losses/flava.py +exclude = models/flava/flava_model.py|modules/losses/flava.py [mypy-PIL.*] ignore_missing_imports = True diff --git a/test/models/test_flava.py b/test/models/test_flava.py index 30be4f8c..e096b434 100644 --- a/test/models/test_flava.py +++ b/test/models/test_flava.py @@ -9,7 +9,7 @@ import torch from test.test_utils import assert_expected from torch import nn -from torchmultimodal.models.flava import ( +from torchmultimodal.models.flava.flava_model import ( flava_image_encoder, flava_model_for_classification, flava_model_for_pretraining, diff --git a/test/modules/layers/test_transformer.py b/test/modules/layers/test_transformer.py index 194cb9a9..e55298d9 100644 --- a/test/modules/layers/test_transformer.py +++ b/test/modules/layers/test_transformer.py @@ -8,7 +8,6 @@ import torch from test.test_utils import assert_expected, set_rng_seed -from torchmultimodal.models.flava import flava_image_encoder from torchmultimodal.modules.layers.transformer import ( FLAVASelfAttention, FLAVATransformerEncoder, @@ -31,11 +30,6 @@ def test_flava_self_attention_value_error(self): with self.assertRaises(ValueError): _ = FLAVASelfAttention(hidden_size=3, num_attention_heads=2) - def test_flava_transformer_without_embeddings_value_error(self): - with self.assertRaises(ValueError): - encoder = flava_image_encoder() - _ = encoder() - def test_flava_encoder_forward(self): output = self.encoder(self.test_input) diff --git a/torchmultimodal/models/flava/__init__.py b/torchmultimodal/models/flava/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/torchmultimodal/models/flava/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchmultimodal/models/flava.py b/torchmultimodal/models/flava/flava_model.py similarity index 100% rename from torchmultimodal/models/flava.py rename to torchmultimodal/models/flava/flava_model.py