-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathmain_trifeatures.py
More file actions
96 lines (75 loc) · 3.26 KB
/
main_trifeatures.py
File metadata and controls
96 lines (75 loc) · 3.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import warnings

import hydra
import numpy as np
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
from hydra.utils import instantiate
from omegaconf import DictConfig
from pytorch_lightning.loggers import TensorBoardLogger

from evaluation.linear_probe import LinearProbingCallback
@hydra.main(version_base=None, config_name="train_trifeatures", config_path="./configs")
def main(cfg: DictConfig):
    """Train or test a multi-modal model on synthetic bimodal Trifeatures data.

    The toy data has controllable attributes (shape, color, texture).
    Models with dedicated constructor wiring in this script are:
        - CoMM [ours!]
        - CLIP
        - CrossSelf

    Any other ``cfg.model.name`` is instantiated from ``cfg.model.model`` with
    only ``optim_kwargs`` and no model-specific keyword arguments.

    Args:
        cfg: Hydra configuration loaded from ``./configs/train_trifeatures``.
            This function reads ``seed``, ``model``, ``optim``, ``data``,
            ``trainer``, ``mode`` and, optionally, ``ckpt_path``.
    """
    # Fix the seeds for reproducibility.
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    # Build the model-specific constructor arguments. The model names are
    # mutually exclusive, hence a single if/elif chain.
    model_name = cfg.model.name
    kwargs = {}
    if model_name == "CoMM":
        kwargs["encoder"] = {
            "encoders": instantiate(cfg.model.encoders),
            "input_adapters": instantiate(cfg.model.adapters),
        }
    elif model_name == "CLIP":
        encoders = instantiate(cfg.model.encoders)
        kwargs["visual"], kwargs["language"] = encoders[0], encoders[1]
        kwargs["image_projection"] = instantiate(cfg.model.clip_image_projection)
        kwargs["text_projection"] = instantiate(cfg.model.clip_text_projection)
    elif model_name == "CrossSelf":
        encoders = instantiate(cfg.model.encoders)
        kwargs["enc1"] = encoders[0]
        kwargs["enc2"] = encoders[1]
        # Both heads come from the same `visual_projection` config, but each
        # one is created by its own `instantiate` call.
        kwargs["head1"] = instantiate(cfg.model.visual_projection)
        kwargs["head2"] = instantiate(cfg.model.visual_projection)

    # Create the model and store the full configuration as hyper-parameters.
    model = instantiate(cfg.model.model, optim_kwargs=cfg.optim, **kwargs)
    model.save_hyperparameters(cfg)

    # Data module used for training / testing the model itself.
    data_module = instantiate(cfg.data.data_module, model=model_name)

    # Linear probing on each task from BimodalTrifeatures: one data module
    # per downstream task.
    downstream_names = ["share", "unique1", "unique2", "synergy"]
    downstream_data_modules = [
        instantiate(cfg.data.data_module, model="Sup", biased=False, task=task)
        for task in downstream_names
    ]

    # Resolve the output directory once so the trainer and its logger share
    # the same path, and any message emitted while resolving it appears once.
    root_dir = build_root_dir(cfg)
    probing = LinearProbingCallback(
        downstream_data_modules,
        names=downstream_names,
        val_loaders=False,
    )
    trainer = instantiate(
        cfg.trainer,
        default_root_dir=root_dir,
        logger=[TensorBoardLogger(root_dir, name="logs")],
        callbacks=[probing],
    )

    # Any mode other than "train" runs the test loop.
    if cfg.mode == "train":
        trainer.fit(model, datamodule=data_module)
    else:
        # `ckpt_path` is optional in the config; fall back to None when absent.
        trainer.test(model, datamodule=data_module, ckpt_path=getattr(cfg, "ckpt_path", None))
def build_root_dir(cfg: "DictConfig") -> str:
    """Return the directory used for this run's logs and checkpoints.

    The default location is
    ``<cfg.trainer.default_root_dir>/<cfg.model.name>/bimodal_trifeatures``.
    In test mode with a checkpoint configured, it becomes a ``test`` folder
    next to that checkpoint, to match the pre-trained model's path. If
    ``cfg.exp_name`` is set, it is appended as the last path component.

    Args:
        cfg: Run configuration. Reads ``trainer.default_root_dir``,
            ``model.name``, ``mode`` and the optional ``ckpt_path`` and
            ``exp_name`` entries.

    Returns:
        The resolved root directory path.

    Warns:
        UserWarning: If ``cfg.mode`` is ``"test"`` and no ``ckpt_path`` is
            set. The default location is returned in that case.
    """
    # Default directory for logs and checkpoints.
    root_dir = os.path.join(cfg.trainer.default_root_dir, cfg.model.name, "bimodal_trifeatures")

    # In test mode, relocate next to the pre-trained checkpoint when one is given.
    if cfg.mode == "test":
        # `ckpt_path` is optional: a config that lacks the key is treated like
        # an explicit None instead of failing on attribute access.
        ckpt_path = getattr(cfg, "ckpt_path", None)
        if ckpt_path is None:
            warnings.warn("`ckpt_path` is not set during testing.", UserWarning, stacklevel=2)
        else:
            root_dir = os.path.join(os.path.dirname(ckpt_path), "test")

    exp_name = getattr(cfg, "exp_name", None)
    if exp_name is not None:
        root_dir = os.path.join(root_dir, exp_name)
    return root_dir
# Script entry point: Hydra's decorator on `main` supplies the `cfg` argument.
if __name__ == '__main__':
    main()