-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcensus.py
105 lines (91 loc) · 3.63 KB
/
census.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import logging
from pathlib import Path
import hydra
import torch
import wandb
from omegaconf import DictConfig, OmegaConf
from src.datasets.census import TestCensusDataModule
from src.models.factory.cosmos.upsampler import Upsampler
from src.models.factory.mlp import MultiTaskMLP
from src.models.factory.phn.phn_wrappers import HyperModel
from src.models.factory.rotograd import RotogradWrapper
from src.utils import set_seed
from src.utils._selectors import get_callbacks, get_ensemble_model, get_optimizer, get_trainer
from src.utils.callbacks.auto_lambda_callback import AutoLambdaCallback
from src.utils.logging_utils import initialize_wandb, install_logging
from src.utils.losses import MultiTaskCrossEntropyLoss
def _build_datamodule(config: DictConfig) -> TestCensusDataModule:
    """Create the Census data module described by ``config.data``.

    ``root`` is forwarded only when ``config.data.root`` is set; otherwise the
    data module's own default for ``root`` is used.
    """
    data_cfg = config.data
    root_kwargs = {} if data_cfg.root is None else {"root": Path(data_cfg.root)}
    return TestCensusDataModule(
        **root_kwargs,
        batch_size=data_cfg.batch_size,
        num_workers=data_cfg.num_workers,
        income=data_cfg.income,
        age=data_cfg.age,
        education=data_cfg.education,
        never_married=data_cfg.never_married,
    )


def _build_mlp(config: DictConfig, dm: TestCensusDataModule, extra_features: int = 0) -> MultiTaskMLP:
    """Create a ``MultiTaskMLP`` for the ``dm.num_tasks`` tasks of ``dm``.

    Args:
        config: Run configuration; ``config.model`` supplies the encoder/decoder specs.
        dm: Data module providing ``num_features`` and ``num_tasks``.
        extra_features: Extra input columns on top of the dataset features
            (used by COSMOS, whose inputs are widened before reaching the MLP).

    Returns:
        The freshly initialised MLP.
    """
    return MultiTaskMLP(
        in_features=dm.num_features + extra_features,
        num_tasks=dm.num_tasks,
        encoder_specs=config.model.encoder_specs,
        decoder_specs=config.model.decoder_specs,
    )


def _build_rotograd(config: DictConfig, dm: TestCensusDataModule):
    """Create a Rotograd-wrapped MLP together with its dedicated Adam optimizer.

    The MLP encoder is the shared backbone, and every one of its decoders
    becomes a Rotograd head, so all enabled tasks take part in training.

    Returns:
        ``(model, optimizer)``: the ``RotogradWrapper`` and an Adam optimizer.
        Backbone and heads use ``config.optimizer.lr``; the wrapper's own
        parameters use one tenth of it.
    """
    mlp = _build_mlp(config, dm)
    backbone = mlp.encoder
    heads = list(mlp.decoders)
    # NOTE(review): latent_size must match the encoder's output width; 256 is
    # hard-coded here — confirm against config.model.encoder_specs.
    model = RotogradWrapper(backbone=backbone, heads=heads, latent_size=256)

    base_lr = config.optimizer.lr
    backbone_and_heads = [{"params": module.parameters()} for module in (backbone, *heads)]
    # Presumably model.parameters() yields only the wrapper's own (rotation)
    # parameters: torch optimizers raise if a parameter appears in more than one
    # group. NOTE(review): confirm RotogradWrapper.parameters() excludes backbone/heads.
    wrapper_group = [{"params": model.parameters(), "lr": base_lr * 0.1}]
    optimizer = torch.optim.Adam(backbone_and_heads + wrapper_group, lr=base_lr)
    return model, optimizer


@hydra.main(config_path="configs/experiment/census", config_name="census")
def my_app(config: DictConfig) -> None:
    """Train and evaluate one multi-task method on the Census benchmark.

    Hydra composes ``config`` from ``configs/experiment/census/census``. The run:
    sets up logging and wandb, seeds RNGs, builds the data module, the model for
    ``config.method.name`` (``phn``, ``cosmos``, ``rotograd``, ``pamal``, ``autol``
    or a plain multi-task MLP), the optimizer, callbacks and trainer, trains for
    ``config.training.epochs`` epochs, evaluates on the test split and finally
    closes the wandb run.
    """
    install_logging()
    logging.info(OmegaConf.to_yaml(config))
    initialize_wandb(config)
    # Seed before any model is constructed so parameter initialisation is reproducible.
    set_seed(config.seed)

    dm = _build_datamodule(config)
    logging.info(f"I am using the following benchmark {dm.name}")

    method = config.method.name
    # Rotograd builds its own optimizer; every other method uses get_optimizer below.
    optimizer = None
    if method == "phn":
        model = HyperModel(dm.name)
    elif method == "cosmos":
        # COSMOS conditions the network on a preference vector joined to the input.
        # NOTE(review): assumes Upsampler appends one weight per task (it is given
        # dm.num_tasks); a fixed +2 would only be correct for two tasks.
        model = _build_mlp(config, dm, extra_features=dm.num_tasks)
    elif method == "rotograd":
        model, optimizer = _build_rotograd(config, dm)
    else:
        model = _build_mlp(config, dm)
    logging.info(model)

    # Method-specific wrappers are applied after the base model has been logged.
    if method == "pamal":
        model = get_ensemble_model(model, dm.num_tasks, config)
    elif method == "cosmos":
        model = Upsampler(dm.num_tasks, model, input_dim=dm.input_dims)

    if optimizer is None:
        optimizer = get_optimizer(config, model.parameters())

    callbacks = get_callbacks(config, dm.num_tasks)
    if method == "autol":
        callbacks.append(AutoLambdaCallback(config.method.meta_lr))

    trainer_kwargs = dict(
        model=model,
        benchmark=dm,
        optimizer=optimizer,
        gpu=0,
        callbacks=callbacks,
        loss_fn=MultiTaskCrossEntropyLoss(),
    )
    trainer = get_trainer(config, trainer_kwargs, dm.num_tasks, model)
    trainer.fit(epochs=config.training.epochs)

    # PaMaL evaluates along interpolations of its ensemble; other methods predict once.
    test_loader = dm.test_dataloader()
    if method == "pamal":
        trainer.predict_interpolations(test_loader)
    else:
        trainer.predict(test_loader=test_loader)
    wandb.finish()
if __name__ == "__main__":
    # The @hydra.main decorator parses command-line overrides and supplies
    # ``config`` itself, so the entry point is invoked without arguments.
    my_app()  # pylint: disable=no-value-for-parameter