-
Notifications
You must be signed in to change notification settings - Fork 40
/
Copy pathlr_pytorch_test.py
47 lines (36 loc) · 1.59 KB
/
lr_pytorch_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import unittest
import tempfile
import torch
from torch.nn import BCELoss as Loss
from torch.optim import SGD as Opt
from lr_pytorch import step, SimpleLogreg
from numpy import array
class TestPyTorchLR(unittest.TestCase):
    """Unit tests for the logistic-regression model in lr_pytorch.

    A fixed manual seed plus hand-set weights make both the forward
    pass and a single SGD step deterministic, so exact values (to 3
    decimal places) can be asserted.
    """

    def setUp(self):
        # Fixed seed so any randomly initialized parameters (presumably
        # the bias — TODO confirm against SimpleLogreg) are reproducible.
        torch.manual_seed(1701)
        self.raw_data = array([[1., 4., 3., 1., 0.],
                               [0., 0., 1., 3., 4.]])
        self.data = torch.from_numpy(self.raw_data).float()
        labels = array([[1], [0]])
        self.labels = torch.from_numpy(labels).float()

        self.model = SimpleLogreg(5)
        # Overwrite the random weights with a known vector (+1 on the
        # first feature, -1 on the last, 0 elsewhere) so the expected
        # outputs below can be computed by hand.
        with torch.no_grad():
            self.model.linear.weight.fill_(0)
            self.model.linear.weight[0, 0] = 1
            self.model.linear.weight[0, 4] = -1

    def test_forward(self):
        """Forward pass on the two fixed examples yields known probabilities."""
        # Call the module instance rather than .forward() directly so
        # nn.Module.__call__ runs (the conventional PyTorch invocation,
        # which also honors any registered hooks).
        self.assertAlmostEqual(0.667, float(self.model(self.data)[0]), 3)
        self.assertAlmostEqual(0.0135, float(self.model(self.data)[1]), 3)

    def test_step(self):
        """One SGD step (lr=0.1) with BCE loss moves parameters to known values."""
        optimizer = Opt(self.model.parameters(), lr=0.1)
        criterion = Loss()
        step(0, 0, self.model, optimizer, criterion, self.data, self.labels)

        weight, bias = list(self.model.parameters())
        self.assertAlmostEqual(float(weight[0][0]), 1.0166, 3)
        self.assertAlmostEqual(float(weight[0][1]), 0.0666, 3)
        self.assertAlmostEqual(float(weight[0][2]), 0.0493, 3)
        self.assertAlmostEqual(float(weight[0][3]), 0.0145, 3)
        self.assertAlmostEqual(float(weight[0][4]), -1.0027, 3)
        self.assertAlmostEqual(float(bias[0]), -0.289, 3)
# Allow running this test module directly: `python lr_pytorch_test.py`.
if __name__ == '__main__':
    unittest.main()