train_one.py
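"""Train a single grid-cell model run with PyTorch Lightning, logged to Weights & Biases.

The script (as written below) reads hyperparameters from the W&B config, builds the
GridCellSystem and TrajectoryDataModule from `src/`, fits the model with `pl.Trainer`,
and finally deletes the run's generated data to free disk space.

Typical invocation (assuming the script is run from the repository root so that `src`
is importable): `python train_one.py`.
"""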
import json
import os
import pprint
import lightning.pytorch as pl
from lightning.pytorch.callbacks import (
DeviceStatsMonitor,
LearningRateMonitor,
ModelCheckpoint,
)
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.profilers import PyTorchProfiler
import shutil
import torch
import wandb
from src.globals import default_config
from src.run import set_seed
from src.systems import GridCellSystem
from src.data import TrajectoryDataModule
# Torch settings.
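# Anomaly detection pinpoints the operation that produced a NaN/Inf during the backward
# pass, at the cost of noticeably slower training.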
torch.autograd.set_detect_anomaly(True)
# On GPUs with Tensor Cores (e.g., an A100 80GB PCIe), `torch.set_float32_matmul_precision('medium' | 'high')`
# trades float32 precision for throughput. For details, see
# https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
# torch.set_float32_matmul_precision('medium')
# print('CUDA available: ', torch.cuda.is_available())
# print('CUDA device count: ', torch.cuda.device_count())
run = wandb.init(project="sorscher-2022-reproduction", config=default_config)
# Convert to a dictionary; otherwise, can't distribute because W&B
# config is not pickle-able.
wandb_config = dict(wandb.config)
# Convert "None" (type: str) to None (type: NoneType)
for key in ["accumulate_grad_batches", "gradient_clip_val", "learning_rate_scheduler"]:
if isinstance(wandb_config[key], str):
if wandb_config[key] == "None":
wandb_config[key] = None
# Create checkpoint directory for this run, and save the config to the directory.
run_checkpoint_dir = os.path.join("lightning_logs", wandb.run.id)
os.makedirs(run_checkpoint_dir)
with open(os.path.join(run_checkpoint_dir, "wandb_config.json"), "w") as fp:
json.dump(obj=wandb_config, fp=fp)
# Make sure we set all seeds for maximal reproducibility!
torch_generator = set_seed(seed=wandb_config["seed"])
wandb_logger = WandbLogger(experiment=run)
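# GridCellSystem (src/systems.py) is the LightningModule that `trainer.fit` optimizes below.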
system = GridCellSystem(wandb_config=wandb_config, wandb_logger=wandb_logger)
# https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.LearningRateMonitor.html
lr_monitor_callback = LearningRateMonitor(logging_interval="step", log_momentum=True)
checkpoint_callback = ModelCheckpoint(
monitor="train/loss=total_loss",
save_top_k=1,
mode="min",
)
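# Note: checkpoint_callback is constructed but intentionally left out of `callbacks` below
# ("Don't need to save these models"); re-add it there to keep the single best checkpoint
# by the logged "train/loss=total_loss" metric.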
callbacks = [
lr_monitor_callback,
# checkpoint_callback, # Don't need to save these models.
]
# if torch.cuda.is_available():
# accelerator = 'cuda'
# devices = torch.cuda.device_count()
# callbacks.extend([
# # DeviceStatsMonitor()
# ])
# print('GPU available.')
# else:
# accelerator = 'auto'
# devices = None
# callbacks.extend([])
# print('No GPU available.')
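# TrajectoryDataModule (src/data.py) builds the training/validation dataloaders that
# `trainer.fit` consumes via `datamodule=` below.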
trajectory_datamodule = TrajectoryDataModule(
wandb_config=wandb_config,
run_checkpoint_dir=run_checkpoint_dir,
torch_generator=torch_generator,
)
print("Created Trajectory Datamodule.")
# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.trainer.trainer.Trainer.html
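# `deterministic=True` below complements `set_seed(...)` above: the seed fixes RNG state,
# while deterministic mode selects reproducible (CUDA/cuDNN) kernel implementations.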
trainer = pl.Trainer(
accelerator="auto",
accumulate_grad_batches=wandb_config["accumulate_grad_batches"],
callbacks=callbacks,
check_val_every_n_epoch=wandb_config["check_val_every_n_epoch"],
default_root_dir=run_checkpoint_dir,
deterministic=True,
devices="auto",
# fast_dev_run=True,
fast_dev_run=False,
logger=wandb_logger,
log_every_n_steps=25,
# overfit_batches=1, # useful for debugging
gradient_clip_val=wandb_config["gradient_clip_val"],
max_epochs=wandb_config["n_epochs"],
num_sanity_val_steps=-1, # Runs all of validation before starting to train.
# profiler="simple", # Simplest profiler
# profiler="advanced", # More advanced profiler
# profiler=PyTorchProfiler(filename=), # PyTorch specific profiler
precision=wandb_config["precision"],
# track_grad_norm=2,
)
# `.fit()` must be called under the `if __name__ == "__main__":` guard below so that
# DataLoader worker processes (which re-import this module when spawned) do not
# recursively launch training.
# See: https://github.com/Lightning-AI/lightning/issues/13039
# See: https://github.com/Lightning-AI/lightning/discussions/9201
# See: https://github.com/Lightning-AI/lightning/discussions/151
if __name__ == "__main__":
pp = pprint.PrettyPrinter(indent=4)
print("W&B Config:")
pp.pprint(wandb_config)
trainer.fit(
model=system,
datamodule=trajectory_datamodule,
)
    # Delete the generated data after training finishes, to save disk space.
    shutil.rmtree(os.path.join(run_checkpoint_dir, "data"))