@@ -25,8 +25,6 @@ def LSTM(
         msg = "LSTM depth must be at least 1. Maybe we should make this a noop?"
         raise ValueError(msg)
 
-    if bi and nO is not None:
-        nO //= 2
     model: Model[Padded, Padded] = Model(
         "lstm",
         forward,
@@ -48,10 +46,11 @@ def PyTorchLSTM(
 
     if depth == 0:
         return noop()  # type: ignore
+    nH = nO
     if bi:
-        nO = nO // 2
+        nH = nO // 2
     pytorch_rnn = PyTorchRNNWrapper(
-        torch.nn.LSTM(nI, nO, depth, bidirectional=bi, dropout=dropout)
+        torch.nn.LSTM(nI, nH, depth, bidirectional=bi, dropout=dropout)
     )
     pytorch_rnn.set_dim("nO", nO)
     pytorch_rnn.set_dim("nI", nI)
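The new `nH` separates two quantities the old code conflated: `torch.nn.LSTM`'s second argument is the per-direction hidden width, while `nO` is the width the wrapper reports as its output. A bidirectional LSTM concatenates the forward and backward states, so the hidden width must be `nO // 2` for the output to stay at `nO`; reassigning `nO` itself, as before, meant `set_dim("nO", nO)` recorded the halved value. A minimal sketch of the dimension relationship (arbitrary sizes, not part of the diff):

```python
# Illustrative check: a bidirectional torch.nn.LSTM with hidden_size=nH
# emits forward and backward states concatenated, so the output width
# is 2 * nH, i.e. the original nO.
import torch

nO, nI = 8, 5
nH = nO // 2  # per-direction hidden width
rnn = torch.nn.LSTM(nI, nH, 1, bidirectional=True)
X = torch.zeros(7, 3, nI)  # (length, batch, nI): torch's default layout
Y, (h, c) = rnn(X)
assert Y.shape == (7, 3, nO)  # both directions concatenated back to nO
```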
@@ -69,7 +68,7 @@ def init(
         model.set_dim("nI", get_width(X))
     if Y is not None:
         model.set_dim("nO", get_width(Y))
-    nO = model.get_dim("nO")
+    nH = int(model.get_dim("nO") / model.get_dim("dirs"))
     nI = model.get_dim("nI")
     depth = model.get_dim("depth")
     dirs = model.get_dim("dirs")
@@ -84,30 +83,30 @@ def init(
     for i in range(depth):
         for j in range(dirs):
             # Input-to-gates weights and biases.
-            params.append(init_W((nO, layer_nI)))
-            params.append(init_W((nO, layer_nI)))
-            params.append(init_W((nO, layer_nI)))
-            params.append(init_W((nO, layer_nI)))
-            params.append(init_b((nO,)))
-            params.append(init_b((nO,)))
-            params.append(init_b((nO,)))
-            params.append(init_b((nO,)))
+            params.append(init_W((nH, layer_nI)))
+            params.append(init_W((nH, layer_nI)))
+            params.append(init_W((nH, layer_nI)))
+            params.append(init_W((nH, layer_nI)))
+            params.append(init_b((nH,)))
+            params.append(init_b((nH,)))
+            params.append(init_b((nH,)))
+            params.append(init_b((nH,)))
             # Hidden-to-gates weights and biases
-            params.append(init_W((nO, nO)))
-            params.append(init_W((nO, nO)))
-            params.append(init_W((nO, nO)))
-            params.append(init_W((nO, nO)))
-            params.append(init_b((nO,)))
-            params.append(init_b((nO,)))
-            params.append(init_b((nO,)))
-            params.append(init_b((nO,)))
-        layer_nI = nO * dirs
+            params.append(init_W((nH, nH)))
+            params.append(init_W((nH, nH)))
+            params.append(init_W((nH, nH)))
+            params.append(init_W((nH, nH)))
+            params.append(init_b((nH,)))
+            params.append(init_b((nH,)))
+            params.append(init_b((nH,)))
+            params.append(init_b((nH,)))
+        layer_nI = nH * dirs
     model.set_param("LSTM", model.ops.xp.concatenate([p.ravel() for p in params]))
-    model.set_param("HC0", zero_init(model.ops, (2, depth, dirs, nO)))
+    model.set_param("HC0", zero_init(model.ops, (2, depth, dirs, nH)))
     size = model.get_param("LSTM").size
-    expected = 4 * dirs * nO * (nO + nI) + dirs * (8 * nO)
+    expected = 4 * dirs * nH * (nH + nI) + dirs * (8 * nH)
     for _ in range(1, depth):
-        expected += 4 * dirs * (nO + nO * dirs) * nO + dirs * (8 * nO)
+        expected += 4 * dirs * (nH + nH * dirs) * nH + dirs * (8 * nH)
     assert size == expected, (size, expected)
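The updated `expected` arithmetic mirrors the allocation loop above: per layer and direction there are four input weight matrices of shape `(nH, layer_nI)`, four recurrent matrices of shape `(nH, nH)`, and eight bias vectors of length `nH`, with `layer_nI` starting at `nI` and becoming `nH * dirs` once the two directions are concatenated. A small standalone sketch of the same count (hypothetical helper, same formula as the diff):

```python
# Hypothetical helper mirroring the size check in init(); not library code.
def expected_lstm_size(nI: int, nH: int, depth: int, dirs: int) -> int:
    total = 0
    layer_nI = nI
    for _ in range(depth):
        # 4 input matrices, 4 recurrent matrices, 8 biases per direction.
        total += dirs * (4 * nH * (layer_nI + nH) + 8 * nH)
        layer_nI = nH * dirs  # deeper layers see both directions concatenated
    return total

# Matches the diff's closed form for a 2-layer bidirectional case:
nI, nH, dirs = 5, 4, 2
first = 4 * dirs * nH * (nH + nI) + dirs * (8 * nH)
deeper = 4 * dirs * (nH + nH * dirs) * nH + dirs * (8 * nH)
assert expected_lstm_size(nI, nH, depth=2, dirs=dirs) == first + deeper
```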