Skip to content

Commit 8c34a73

Browse files
author
Ian Goodfellow
committed
Merge pull request #1336 from goodfeli/thoughtididthis
added arg_of_sigmoid
2 parents 7b8bac8 + 43e43ca commit 8c34a73

File tree

2 files changed

+71
-9
lines changed

2 files changed

+71
-9
lines changed

pylearn2/expr/nnet.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,46 @@ def arg_of_softmax(Y_hat):
114114
return z
115115

116116

117+
def arg_of_sigmoid(Y_hat):
    """
    Given the output of a call to theano.tensor.nnet.sigmoid,
    returns the argument to the sigmoid (by tracing the Theano
    graph).

    Parameters
    ----------
    Y_hat : Variable
        T.nnet.sigmoid(Z)

    Returns
    -------
    Z : Variable
        The variable that was passed to T.nnet.sigmoid to create `Y_hat`.
        Raises an error if `Y_hat` is not actually the output of a theano
        sigmoid.
    """
    assert hasattr(Y_hat, 'owner')
    owner = Y_hat.owner
    assert owner is not None
    op = owner.op
    # A Print op is transparent for graph tracing: look through it at the
    # variable it wraps.
    if isinstance(op, Print):
        assert len(owner.inputs) == 1
        Y_hat, = owner.inputs
        owner = Y_hat.owner
        # Bug fix: the variable wrapped by the Print may itself have no
        # owner (e.g. it is a graph input). Without this guard the next
        # line raises an opaque AttributeError on None.
        assert owner is not None
        op = owner.op
    success = False
    if isinstance(op, T.Elemwise):
        if isinstance(op.scalar_op, T.nnet.sigm.ScalarSigmoid):
            success = True
    if not success:
        raise TypeError("Expected Y_hat to be the output of a sigmoid, "
                        "but it appears to be the output of " + str(op) +
                        " of type " + str(type(op)))
    z, = owner.inputs
    # NOTE(review): mirrors arg_of_softmax's restriction to 2D (batch,
    # unit) arguments — confirm this is intended for sigmoid as well.
    assert z.ndim == 2
    return z
155+
156+
117157
def kl(Y, Y_hat, batch_axis):
118158
"""
119159
Warning: This function expects a sigmoid nonlinearity in the
@@ -323,4 +363,4 @@ def compute_f1(precision, recall):
323363
"""
324364
f1 = (2. * precision * recall /
325365
T.maximum(1, precision + recall))
326-
return f1
366+
return f1

pylearn2/expr/tests/test_nnet.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
from theano import tensor as T
1515

1616
from pylearn2.models.mlp import MLP, Sigmoid
17+
from pylearn2.expr.nnet import arg_of_sigmoid
1718
from pylearn2.expr.nnet import pseudoinverse_softmax_numpy
1819
from pylearn2.expr.nnet import softmax_numpy
1920
from pylearn2.expr.nnet import softmax_ratio
2021
from pylearn2.expr.nnet import compute_recall
2122
from pylearn2.expr.nnet import kl
22-
from pylearn2.expr.nnet import elemwise_kl
23+
from pylearn2.expr.nnet import elemwise_kl
2324
from pylearn2.utils import sharedX
2425

2526

@@ -83,7 +84,7 @@ def test_kl():
8384
"""
8485
init_mode = theano.config.compute_test_value
8586
theano.config.compute_test_value = 'raise'
86-
87+
8788
try:
8889
mlp = MLP(layers=[Sigmoid(dim=10, layer_name='Y', irange=0.1)],
8990
nvis=10)
@@ -101,7 +102,7 @@ def test_kl():
101102
np.testing.assert_raises(ValueError, kl, Y, Y_hat, 1)
102103
Y.tag.test_value[2][3] = -0.1
103104
np.testing.assert_raises(ValueError, kl, Y, Y_hat, 1)
104-
105+
105106
finally:
106107
theano.config.compute_test_value = init_mode
107108

@@ -112,10 +113,10 @@ def test_elemwise_kl():
112113
input.
113114
"""
114115
init_mode = theano.config.compute_test_value
115-
theano.config.compute_test_value = 'raise'
116-
116+
theano.config.compute_test_value = 'raise'
117+
117118
try:
118-
mlp = MLP(layers=[Sigmoid(dim=10, layer_name='Y', irange=0.1)],
119+
mlp = MLP(layers=[Sigmoid(dim=10, layer_name='Y', irange=0.1)],
119120
nvis=10)
120121
X = mlp.get_input_space().make_theano_batch()
121122
Y = mlp.get_output_space().make_theano_batch()
@@ -131,8 +132,29 @@ def test_elemwise_kl():
131132
np.testing.assert_raises(ValueError, elemwise_kl, Y, Y_hat)
132133
Y.tag.test_value[2][3] = -0.1
133134
np.testing.assert_raises(ValueError, elemwise_kl, Y, Y_hat)
134-
135+
135136
finally:
136137
theano.config.compute_test_value = init_mode
137138

138-
139+
def test_arg_of_sigmoid_good():
    """
    Tests that arg_of_sigmoid works when given a good input.
    """

    # Build a sigmoid graph, then trace it back and check that the
    # recovered argument is the exact same variable we started from.
    sigmoid_input = T.matrix()
    sigmoid_output = T.nnet.sigmoid(sigmoid_input)
    recovered = arg_of_sigmoid(sigmoid_output)
    assert recovered is sigmoid_input
148+
149+
def test_arg_of_sigmoid_bad():
    """
    Tests that arg_of_sigmoid raises an error when given a bad input.
    """

    X = T.matrix()
    # A softmax is not a sigmoid, so arg_of_sigmoid must reject its output.
    Y = T.nnet.softmax(X)
    try:
        # Bug fix: the original bound the result to an unused local `Z`
        # (flake8 F841); the return value is irrelevant here.
        arg_of_sigmoid(Y)
    except TypeError:
        return
    assert False  # Should have failed

0 commit comments

Comments
 (0)