forked from peisuke/ImplicitGeometricRegularization.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
88 lines (67 loc) · 2.85 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import argparse
import os
import numpy as np
# import open3d as o3d
import torch
import torch.optim as optim
from dataset import Dataset
from utils import sample_fake
from utils import build_network
from utils import train
def load_data(filename, output_name):
    """Load a point cloud, normalize it, and cache it as ``<output_name>.npy``.

    The points are uniformly scaled so that the largest axis-aligned extent of
    the cloud becomes 2, then translated so the bounding-box center sits at the
    origin.

    Args:
        filename: Path to the point cloud. ``.npy`` files are read with NumPy;
            any other extension (e.g. pcd, ply) is read with Open3D.
        output_name: Path prefix for the cached array; ``.npy`` is appended.

    Returns:
        The normalized points as a float64 ndarray.

    Raises:
        ValueError: If no points were loaded from ``filename``.
    """
    if str(filename).lower().endswith('.npy'):
        pts = np.load(filename)
    else:
        # Imported lazily: the module-level open3d import is commented out, so a
        # bare ``o3d`` reference here raised NameError on every call.
        import open3d as o3d
        pts = np.asarray(o3d.io.read_point_cloud(filename).points)
    # Force float so the in-place centering below also works for integer input.
    pts = np.asarray(pts, dtype=np.float64)
    if pts.size == 0:
        raise ValueError('no points loaded from {!r}'.format(filename))
    size = pts.max(axis=0) - pts.min(axis=0)
    extent = size.max()
    # A degenerate cloud (single point / all points identical) has zero extent;
    # skip scaling instead of dividing by zero and producing NaNs.
    if extent > 0:
        pts = 2 * pts / extent
    pts -= (pts.max(axis=0) + pts.min(axis=0)) / 2
    np.save("{}.npy".format(output_name), pts)
    return pts
def get_batchsize(iter):
    """Return the DataLoader batch size to use at epoch ``iter``.

    Epochs below 10 use 32, and the size doubles at each multiple of 10 up to
    512 (epochs 40-49). Epochs 50-99 use 1024, and every epoch from 100 on
    uses 2048.
    """
    # NOTE: the parameter name shadows the builtin ``iter``; it is kept because
    # callers may pass it by keyword.
    # (exclusive upper epoch bound, batch size) pairs, tested in order.
    steps = (
        (10, 32),
        (20, 64),
        (30, 128),
        (40, 256),
        (50, 512),
        (100, 1024),
    )
    return next((size for limit, size in steps if iter < limit), 2048)
def main(argv=None):
    """Train the network on a point cloud and write checkpoints to ``models/<name>/``.

    With no arguments this trains on ``input/gargoyle.npy`` for 1000 epochs,
    matching the script's previous hard-coded behavior.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description='Train an implicit-surface network on a point cloud.')
    parser.add_argument('--input', '-i', type=str, required=False,
                        help='input filename (pcd, ply); if omitted, '
                             'input/<name>.npy is loaded')
    # Defaults reproduce the values that used to be hard-coded below the parser.
    parser.add_argument('--name', '-n', type=str, default='gargoyle',
                        help='output model name')
    parser.add_argument('--epochs', '-e', type=int, default=1000,
                        help='number of training epochs')
    parser.add_argument('--fast', action='store_true',
                        help='batch size scheduling')
    args = parser.parse_args(argv)

    output_name = args.name
    nb_epochs = args.epochs
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if args.input is not None:
        x = load_data(args.input, output_name)
    else:
        x = np.load(os.path.join('input', '{}.npy'.format(output_name)))

    dataset = Dataset(x, knn=50)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

    net = build_network(input_dim=3)
    net.to(device)
    model = net
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # Previously this wrapper was created but never used, so training ran on
        # one GPU despite the message. NOTE(review): DataParallel only forwards
        # __call__/forward and nn.Module methods; confirm utils.train uses nothing
        # else on the model.
        model = torch.nn.DataParallel(net)
    # The wrapper shares net's parameters, so optimizing net's is sufficient.
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    model_dir = os.path.join('models', output_name)
    os.makedirs(model_dir, exist_ok=True)

    for itr in range(nb_epochs):
        if args.fast:
            batch_size = get_batchsize(itr)
            if batch_size != data_loader.batch_size:
                # DataLoader's batch size is fixed at construction; rebuild it.
                data_loader = torch.utils.data.DataLoader(
                    dataset, batch_size=batch_size, shuffle=True)
        loss = train(model, optimizer, data_loader, device)
        print(itr, loss)
        if itr % 100 == 0:
            # Save the unwrapped module so state_dict keys carry no 'module.' prefix.
            torch.save(net.state_dict(),
                       os.path.join(model_dir, 'model_{0:04d}.pth'.format(itr)))
    torch.save(net.state_dict(), os.path.join(model_dir, 'model_final.pth'))


if __name__ == '__main__':
    main()