1414from theano import tensor as T
1515
1616from pylearn2 .models .mlp import MLP , Sigmoid
17+ from pylearn2 .expr .nnet import arg_of_sigmoid
1718from pylearn2 .expr .nnet import pseudoinverse_softmax_numpy
1819from pylearn2 .expr .nnet import softmax_numpy
1920from pylearn2 .expr .nnet import softmax_ratio
2021from pylearn2 .expr .nnet import compute_recall
2122from pylearn2 .expr .nnet import kl
22- from pylearn2 .expr .nnet import elemwise_kl
23+ from pylearn2 .expr .nnet import elemwise_kl
2324from pylearn2 .utils import sharedX
2425
2526
@@ -83,7 +84,7 @@ def test_kl():
8384 """
8485 init_mode = theano .config .compute_test_value
8586 theano .config .compute_test_value = 'raise'
86-
87+
8788 try :
8889 mlp = MLP (layers = [Sigmoid (dim = 10 , layer_name = 'Y' , irange = 0.1 )],
8990 nvis = 10 )
@@ -101,7 +102,7 @@ def test_kl():
101102 np .testing .assert_raises (ValueError , kl , Y , Y_hat , 1 )
102103 Y .tag .test_value [2 ][3 ] = - 0.1
103104 np .testing .assert_raises (ValueError , kl , Y , Y_hat , 1 )
104-
105+
105106 finally :
106107 theano .config .compute_test_value = init_mode
107108
@@ -112,10 +113,10 @@ def test_elemwise_kl():
112113 input.
113114 """
114115 init_mode = theano .config .compute_test_value
115- theano .config .compute_test_value = 'raise'
116-
116+ theano .config .compute_test_value = 'raise'
117+
117118 try :
118- mlp = MLP (layers = [Sigmoid (dim = 10 , layer_name = 'Y' , irange = 0.1 )],
119+ mlp = MLP (layers = [Sigmoid (dim = 10 , layer_name = 'Y' , irange = 0.1 )],
119120 nvis = 10 )
120121 X = mlp .get_input_space ().make_theano_batch ()
121122 Y = mlp .get_output_space ().make_theano_batch ()
@@ -131,8 +132,29 @@ def test_elemwise_kl():
131132 np .testing .assert_raises (ValueError , elemwise_kl , Y , Y_hat )
132133 Y .tag .test_value [2 ][3 ] = - 0.1
133134 np .testing .assert_raises (ValueError , elemwise_kl , Y , Y_hat )
134-
135+
135136 finally :
136137 theano .config .compute_test_value = init_mode
137138
138-
139+ def test_arg_of_sigmoid_good ():
140+ """
141+ Tests that arg_of_sigmoid works when given a good input.
142+ """
143+
144+ X = T .matrix ()
145+ Y = T .nnet .sigmoid (X )
146+ Z = arg_of_sigmoid (Y )
147+ assert X is Z
148+
149+ def test_arg_of_sigmoid_bad ():
150+ """
151+ Tests that arg_of_sigmoid raises an error when given a bad input.
152+ """
153+
154+ X = T .matrix ()
155+ Y = T .nnet .softmax (X )
156+ try :
157+ Z = arg_of_sigmoid (Y )
158+ except TypeError :
159+ return
160+ assert False # Should have failed
0 commit comments