-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathtest_random_forest.py
70 lines (50 loc) · 1.85 KB
/
test_random_forest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import io
import os
import sys
import unittest

from prettytable import PrettyTable
from sklearn.ensemble import RandomForestRegressor

from random_forest import RandomForest

class TestModelMethods(unittest.TestCase):
    """Test the methods of the random forest model wrapper.

    Each test gets a fresh ``RandomForest`` whose ``model_file`` points at
    ``MODEL_FILE``. That file is deleted again in ``tearDown`` so no test
    depends on a file written by another test (or a previous run).
    """

    # Path the wrapper saves to / loads from during these tests.
    MODEL_FILE = 'testModel.pth'

    def setUp(self):
        """Create a fresh wrapper pointed at the test model file."""
        self.rf = RandomForest()
        self.rf.model_file = self.MODEL_FILE

    def tearDown(self):
        """Delete the model file a test may have written, if any."""
        # Remove MODEL_FILE rather than self.rf.model_file: some tests
        # repoint model_file at a path that must not be touched.
        try:
            os.remove(self.MODEL_FILE)
        except FileNotFoundError:
            pass  # this test did not save anything

    def testDataVariables(self):
        """Feature and target variables are defined (non-empty) in the model."""
        self.assertGreater(len(self.rf.features), 0)
        self.assertGreater(len(self.rf.target), 0)

    def testDataLoading(self):
        """Every data split has non-empty data, samples and labels."""
        for dataset in ['train', 'val', 'test', 'full']:
            # subTest reports each split separately instead of aborting
            # at the first split that fails.
            with self.subTest(dataset=dataset):
                self.assertGreater(len(self.rf.data[dataset]), 0)
                self.assertGreater(len(self.rf.samples[dataset]), 0)
                self.assertGreater(len(self.rf.labels[dataset]), 0)

    def testSaveExistingModel(self):
        """Saving a freshly constructed (unfitted) regressor reports success."""
        self.rf.model = RandomForestRegressor()
        self.assertTrue(self.rf._saveModel())

    def testSaveNonExistingModel(self):
        """Saving when no model is set reports failure."""
        self.rf.model = None
        self.assertFalse(self.rf._saveModel())

    def testLoadExisitngModel(self):  # (sic) name kept for existing test selectors
        """Loading a model that was just saved reports success."""
        # Set and save a model explicitly. The original relied on whatever
        # RandomForest() starts with, or on a file left behind by another
        # test, which made the result depend on execution order.
        self.rf.model = RandomForestRegressor()
        self.assertTrue(self.rf._saveModel())
        self.assertTrue(self.rf._loadModel())

    def testLoadNonExisitngModel(self):  # (sic) name kept for existing test selectors
        """Loading from a path that does not exist reports failure."""
        self.rf.model_file = 'non-existing/path.x'
        self.assertFalse(self.rf._loadModel())
if __name__ == '__main__':
    # Executed directly (not imported): discover and run this module's tests.
    unittest.main(module='__main__')