-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
executable file
·99 lines (83 loc) · 2.86 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import hydra
import json
import logging
import neptune
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader
from models import WGAN, JSGAN, CramerGAN
from trainer import Trainer
from utils import get_RICH
logger = logging.getLogger(__name__)


def _init_neptune(config: DictConfig):
    """Create the Neptune experiment used to log this training run.

    Uses the legacy ``neptune.init`` / ``project.create_experiment`` API.

    Args:
        config: Hydra config; reads ``neptune.project_name``,
            ``neptune.experiment_name`` and ``neptune.tags``.

    Returns:
        The object returned by ``project.create_experiment``; it is passed
        to the ``Trainer`` as ``neptune_logger``.
    """
    logger.info("Setting up logger")
    # Credentials must not live in source code: read the token from the
    # environment. If the variable is unset, None is passed and the Neptune
    # client applies its own default lookup/validation.
    project = neptune.init(
        project_qualified_name=config["neptune"]["project_name"],
        api_token=os.environ.get("NEPTUNE_API_TOKEN"),
    )
    logger.info("Setting up experiment")
    return project.create_experiment(
        name=config["neptune"]["experiment_name"],
        tags=OmegaConf.to_container(config["neptune"]["tags"], resolve=True),
        # resolve=True logs concrete values rather than raw "${...}"
        # interpolation strings from the Hydra config.
        params=OmegaConf.to_container(config, resolve=True),
    )


def _make_loader(dataset, config: DictConfig, *, shuffle: bool) -> DataLoader:
    """Wrap *dataset* in a DataLoader using the run's batch/worker settings.

    Args:
        dataset: dataset returned by ``get_RICH``.
        config: Hydra config; reads ``batch_size`` and ``num_workers``.
        shuffle: whether the loader reshuffles samples each epoch.
    """
    return DataLoader(
        dataset=dataset,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        shuffle=shuffle,
        pin_memory=True,
    )


def _build_model(config: DictConfig):
    """Instantiate the GAN class selected by ``config["gan_type"]``.

    Supported values: ``"js"`` -> JSGAN, ``"wgan"`` -> WGAN,
    ``"cramer"`` -> CramerGAN. The whole config is passed to the constructor.

    Raises:
        NameError: if ``gan_type`` is not one of the supported values
            (exception type kept for compatibility with earlier versions).
    """
    gan_classes = {"js": JSGAN, "wgan": WGAN, "cramer": CramerGAN}
    gan_type = config["gan_type"]
    if gan_type not in gan_classes:
        raise NameError("Unknown GAN type: {}".format(gan_type))
    return gan_classes[gan_type](config)


def _prepare_save_path(config: DictConfig) -> str:
    """Return the per-experiment output directory, creating it if needed.

    The directory is ``<save_path>/<gan_type>_<generator_architecture>_
    <critic_architecture>_<experiment_data>``.
    """
    folder_name = (
        f"{config['gan_type']}_{config['generator_architecture']}_"
        f"{config['critic_architecture']}_{config['experiment_data']}"
    )
    save_path = os.path.join(config["save_path"], folder_name)
    # Create the directory that is actually handed to the Trainer (not just
    # the leaf name relative to the CWD), including missing parents;
    # exist_ok avoids a check-then-create race.
    os.makedirs(save_path, exist_ok=True)
    return save_path


@hydra.main(config_path="config", config_name="config")
def main(config: DictConfig):
    """Train a GAN on the data loaded by ``get_RICH`` using Hydra's *config*.

    Steps: set up the Neptune experiment, build train/validation loaders,
    instantiate the GAN chosen by ``config["gan_type"]``, prepare the output
    directory, call ``Trainer.calculate_inference_time``, then
    ``Trainer.fit`` starting at ``config["starting_epoch"]``.
    """
    neptune_logger = _init_neptune(config)

    logger.info("Getting dataset")
    # Only the two datasets are used below; the other returned values are
    # kept (underscore-prefixed) to document get_RICH's return structure.
    _input_size, _dll_shape, train_dataset, valid_dataset, _scaler = get_RICH(
        config["data"]["particle"],
        config["data"]["drop_weights"],
        config["data"]["data_path"],
    )

    logger.info("Creating train dataloader")
    train_loader = _make_loader(train_dataset, config, shuffle=True)
    logger.info("Creating validation dataloader")
    # NOTE(review): validation data is shuffled, matching previous behavior;
    # confirm the Trainer depends on this before switching to shuffle=False.
    validation_loader = _make_loader(valid_dataset, config, shuffle=True)

    logger.info("Creating GAN model")
    model = _build_model(config)

    save_path = _prepare_save_path(config)

    logger.info("Creating trainer")
    trainer = Trainer(
        gan_model=model,
        max_epoch=config["max_epoch"],
        display_step=config["display_step"],
        critic_step=config["critic_step"],
        device=config["device"],
        save_path=save_path,
        freeze_generator=config["freeze_generator"],
        neptune_logger=neptune_logger,
    )

    logger.info("Calculating inference time")
    # Input shape is (1, generator input size), read from the generator config.
    trainer.calculate_inference_time(
        (1, config["generator"]["params"]["input_size"]), config["repetitions"]
    )

    logger.info("Starting train")
    trainer.fit(
        train_loader=train_loader,
        validation_loader=validation_loader,
        start=config["starting_epoch"],
    )


if __name__ == "__main__":
    main()