-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_graph_mnist_reg.cc
96 lines (85 loc) · 3.15 KB
/
test_graph_mnist_reg.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
/* test_graph_mnist_reg.cc for LITE
* Copyright (C) 2017 Mo Zhou <[email protected]>
* MIT License
*/
#include <iostream>
#include <vector>
#include "leicht.hpp"
using namespace std;
// Hyper-parameters and bookkeeping for this MNIST regression test.

// Mini-batch size. 100 divides both split sizes (37800 training and 4200
// validation samples) evenly, so both splits partition into whole batches.
unsigned int batchsize = 100;
double lr = 1e-3; // SGD learning rate (reference lr=1e-3)
int maxiter = 1000; // intended number of training iterations
int iepoch = 37800/batchsize; // mini-batches per pass over the 37800 training samples
int overfit = 10; // (DEBUG) let it overfit on howmany batches; currently unused
int testevery = 100; // run a validation pass every `testevery` iterations
std::vector<double> validaccuhist; // validation accuracy history, printed at exit
std::vector<double> validlosshist; // validation loss history, printed at exit
int
main(void)
{
leicht_version();
cout << ">> Reading MNIST training dataset" << endl;
Tensor<double> trainImages (37800, 784);
trainImages.setName("trainImages");
leicht_hdf5_read("mnist.th.h5", "/train/images", 0, 0, 37800, 784, trainImages.data);
Tensor<double> trainLabels (37800, 1);
trainLabels.setName("trainLabels");
leicht_hdf5_read("mnist.th.h5", "/train/labels", 0, 0, 37800, 1, trainLabels.data);
cout << ">> Reading MNIST validation dataset" << endl;
Tensor<double> valImages(4200, 784); valImages.setName("valImages");
leicht_hdf5_read("mnist.th.h5", "/val/images", 0, 0, 4200, 784, valImages.data);
Tensor<double> valLabels(4200, 1); valLabels.setName("valLabels");
leicht_hdf5_read("mnist.th.h5", "/val/labels", 0, 0, 4200, 1, valLabels.data);
cout << ">> Initialize Network" << endl;
Graph<double> net (784, 1, 100);
net.name = "test_graph_mnist_cls2.cc";
net.addLayer("fc1", "Linear", "entryDataBlob", "fc1", 512);
net.addLayer("relu1", "Relu", "fc1", "fc1");
net.addLayer("fc2", "Linear", "fc1", "fc2", 192);
net.addLayer("relu2", "Relu", "fc2", "fc2");
net.addLayer("fc3", "Linear", "fc2", "fc3", 1);
net.addLayer("mse1", "MSELoss", "fc3", "mse1", "entryLabelBlob");
net.dump();
cout << ">> Start training" << endl;
for (int iteration = 0; iteration < 500; iteration++) {
leicht_bar_train(iteration);
// -- get batch
Tensor<double>* batchIm = new Tensor<double> (100, 784);
batchIm->copy(trainImages.data + (iteration%iepoch)*batchsize*784, batchsize*784);
batchIm->transpose_();
batchIm->scal_(1./255.);
net.getBlob("entryDataBlob", true)->value.copy(batchIm->data, 784*batchsize);
net.getBlob("entryLabelBlob", true)->value.copy(trainLabels.data + (iteration%iepoch)*batchsize*1, batchsize*1);
delete batchIm;
// -- forward
net.forward();
// -- zerograd
net.zeroGrad();
// -- backward
net.backward();
// -- report
net.report();
// -- update
net.update(lr, "SGD");
// -- test every
if ((iteration+1)%testevery==0) {
leicht_bar_val(iteration);
vector<double> accuracy;
vector<double> l;
for (int t = 0; t < 42; t++) {
// -- get batch
Tensor<double>* tbatchIm = new Tensor<double> (100, 784);
tbatchIm->copy(valImages.data + t*batchsize*784, batchsize*784);
tbatchIm->transpose_();
tbatchIm->scal_(1./255.);
net.getBlob("entryDataBlob", true)->value.copy(tbatchIm->data, 784*batchsize);
net.getBlob("entryLabelBlob", true)->value.copy(valLabels.data + t*batchsize*1, batchsize*1);
delete tbatchIm;
net.forward(); net.report();
}
}
}
// show history
for (auto i : validlosshist) cout << i << " "; cout << endl;
for (auto i : validaccuhist) cout << i << " "; cout << endl;
return 0;
}