# main.py (forked from ecovision-uzh/sat-sinr)

import warnings

import hydra
import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning.callbacks import ModelCheckpoint

from embedders import get_embedder
from models import *
from dataset import *

warnings.filterwarnings("ignore")


def get_logger(params, tag=""):
    """Instantiates a Weights & Biases logger with a descriptive run name."""
    wandb.finish()  # close any previous run before starting a new one
    name = params.model
    if params.model in ("sinr", "log_reg"):
        name += " " + params.dataset.predictors
    elif "sat" in params.model:
        name += " " + params.embedder
    if params.validate:
        name += " val"
    name += " " + tag
    logger = hydra.utils.instantiate(
        {
            "_target_": "pytorch_lightning.loggers.WandbLogger",
            "name": name,
            "save_dir": params.local.logs_dir_path,
            "project": "sinr_on_glc23",
        }
    )
    return logger
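

# Example (hypothetical config values): with model="sat_sinr", embedder="sentinel2",
# validate=True, and tag="baseline", get_logger names the W&B run
# "sat_sinr sentinel2 val baseline".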


def train_model(
    params,
    dataset,
    train_loader,
    val_loader,
    provide_model=None,
    logger=None,
    validate=False,
):
    """
    Instantiates the model (unless one is provided), checkpoints the epoch with
    the lowest validation loss, and either trains or only validates.
    """
    torch.set_float32_matmul_precision("medium")
    if provide_model is None:
        if params.model in ("sinr", "log_reg"):
            model = SINR(params, dataset)
        elif "sat" in params.model:
            model = SAT_SINR(params, dataset, get_embedder(params))
        else:
            raise ValueError(f"Unknown model: {params.model}")
    else:
        model = provide_model
    # Keep only the single best checkpoint, ranked by validation loss.
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1,
        monitor="val_loss",
        mode="min",
        dirpath=params.local.cp_dir_path,
        filename=logger._name + "{val_loss:.4f}",  # prefix files with the run name
    )
    trainer = pl.Trainer(
        max_epochs=params.epochs,
        accelerator=("gpu" if params.local.gpu else "cpu"),
        devices=1,
        precision="16-mixed",
        logger=logger,
        log_every_n_steps=50,
        callbacks=[checkpoint_callback],
    )
    if validate:
        trainer.validate(model=model, dataloaders=[val_loader])
    else:
        trainer.fit(model, train_loader, val_loader)


def load_cp(params, dataset):
    """Loads a model from the checkpoint path given in the config."""
    if params.model in ("sinr", "log_reg"):
        model = SINR.load_from_checkpoint(
            params.checkpoint, params=params, dataset=dataset
        )
    elif "sat" in params.model:
        model = SAT_SINR.load_from_checkpoint(
            params.checkpoint,
            params=params,
            dataset=dataset,
            sent2_net=get_embedder(params),
        )
    else:
        raise ValueError(f"Unknown model: {params.model}")
    return model
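

# Note: params.checkpoint comes from the Hydra config; the literal string "None"
# (not Python's None) signals training from scratch, as handled in main() below.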


@hydra.main(version_base=None, config_path="config", config_name="base_config.yaml")
def main(params):
    """Builds the data loaders, sets up logging, and runs training or validation."""
    dataset, train_loader, val_loader = create_datasets(params)
    logger = get_logger(params, tag=params.tag)
    if params.checkpoint != "None":
        # Resume from (or validate) an existing checkpoint.
        model = load_cp(params, dataset)
        train_model(
            params,
            dataset,
            train_loader,
            val_loader,
            provide_model=model,
            logger=logger,
            validate=params.validate,
        )
    else:
        train_model(params, dataset, train_loader, val_loader, logger=logger)
    wandb.finish()


if __name__ == "__main__":
    try:
        main()
    except Exception:
        # In case of a crash, still flush everything to W&B, then re-raise
        # so the failure is not silently swallowed.
        wandb.finish()
        raise
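

# Example invocations (hypothetical overrides; the real keys and defaults live
# in config/base_config.yaml):
#   python main.py model=sinr epochs=20 tag=baseline
#   python main.py model=sat_sinr checkpoint=/path/to/model.ckpt validate=true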