-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy path train.py
63 lines (54 loc) · 2.05 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
from source.trainer import EDGSTrainer
from source.utils_aux import set_seed
import omegaconf
import wandb
import hydra
from argparse import Namespace
from omegaconf import OmegaConf
def _write_cfg_args(dataset_cfg) -> None:
    """Write a ``cfg_args`` file into the output folder ``dataset_cfg.model_path``.

    The file contains ``str(argparse.Namespace(...))``. Its keys mirror the
    ModelParams/PipelineParams argument names of the reference 3D Gaussian
    Splatting code, whose tools read this file back with ``eval()``.
    NOTE(review): confirm which downstream tools (render/eval scripts,
    viewers) consume this file.

    Args:
        dataset_cfg: The ``cfg.gs.dataset`` config node. Reads ``model_path``,
            ``source_path``, ``images``, ``white_background`` and ``data_device``.
    """
    params = {
        # NOTE(review): hard-coded; assumed to match the model's SH degree.
        "sh_degree": 3,
        "source_path": dataset_cfg.source_path,
        "model_path": dataset_cfg.model_path,
        "images": dataset_cfg.images,
        "depths": "",
        "resolution": -1,
        # Written without a leading underscore, like the other keys above.
        # Argparse dest names drop the ``_`` prefix and loaders match on the
        # dest name. The previous "_white_background" key was never picked up
        # as ``white_background``.
        "white_background": dataset_cfg.white_background,
        "train_test_exp": False,
        "data_device": dataset_cfg.data_device,
        "eval": False,
        "convert_SHs_python": False,
        "compute_cov3D_python": False,
        "debug": False,
        "antialiasing": False,
    }
    # Default (locale) encoding kept on purpose: readers of this file
    # typically open it without an explicit encoding as well.
    with open(os.path.join(dataset_cfg.model_path, "cfg_args"), "w") as cfg_log_f:
        cfg_log_f.write(str(Namespace(**params)))


@hydra.main(config_path="configs", config_name="train", version_base="1.2")
def main(cfg: omegaconf.DictConfig) -> None:
    """Train an EDGS model from a Hydra config.

    Steps:
    1. Start a Weights & Biases run, logging the fully resolved config.
    2. Resolve ``cfg`` in place and call ``set_seed(cfg.seed)``.
    3. Create the output folder and write its ``cfg_args`` file.
    4. Instantiate the model from ``cfg.gs``.
    5. Build an ``EDGSTrainer``, load checkpoints, start its timer, then run
       ``init_with_corr`` and ``train``.
    6. Close the W&B run.

    Args:
        cfg: Config composed by Hydra from ``configs/train`` plus any
            command-line overrides; injected by ``@hydra.main``.
    """
    wandb.init(
        entity=cfg.wandb.entity,
        project=cfg.wandb.project,
        # throw_on_missing makes this fail fast if any mandatory ("???")
        # value was left unset, instead of logging a placeholder.
        config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
        tags=[cfg.wandb.tag],
        name=cfg.wandb.name,
        mode=cfg.wandb.mode,
    )
    # Resolve interpolations in place so every later consumer of cfg sees
    # concrete values.
    OmegaConf.resolve(cfg)
    set_seed(cfg.seed)

    # Init output folder.
    model_path = cfg.gs.dataset.model_path
    print(f"Output folder: {model_path}")
    os.makedirs(model_path, exist_ok=True)
    _write_cfg_args(cfg.gs.dataset)

    # Instantiate the Gaussian-splatting model described by cfg.gs.
    gs = hydra.utils.instantiate(cfg.gs)

    # Init trainer and launch training.
    trainer = EDGSTrainer(GS=gs, training_config=cfg.gs.opt, device=cfg.device)
    trainer.load_checkpoints(cfg.load)
    trainer.timer.start()
    trainer.init_with_corr(cfg.init_wC)
    trainer.train(cfg.train)

    # All done.
    wandb.finish()
    print("\nTraining complete.")


if __name__ == "__main__":
    # Called with no arguments: @hydra.main composes cfg from the config
    # files and command-line overrides and passes it in.
    main()