1
- from typing import Tuple , Callable , Optional , TypeVar , Dict
1
+ from typing import Any , Tuple , Callable , Optional , TypeVar , Dict
2
2
3
3
from ..model import Model
4
4
from ..config import registry
5
5
from ..types import ArrayXd , XY_XY_OutT
6
6
from ..util import get_width
7
7
8
8
9
- InT = TypeVar ("InT" , bound = ArrayXd )
9
+ InT = TypeVar ("InT" , bound = Any )
10
10
OutT = TypeVar ("OutT" , bound = ArrayXd )
11
11
12
12
@@ -30,7 +30,7 @@ def add(
30
30
return Model ("add" , forward , init = init , dims = dims , layers = layers )
31
31
32
32
33
- def forward (model : Model [InT , InT ], X : InT , is_train : bool ) -> Tuple [InT , Callable ]:
33
+ def forward (model : Model [InT , OutT ], X : InT , is_train : bool ) -> Tuple [OutT , Callable ]:
34
34
if not model .layers :
35
35
return X , lambda dY : dY
36
36
Y , first_callback = model .layers [0 ](X , is_train = is_train )
@@ -40,7 +40,7 @@ def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callab
40
40
Y += layer_Y
41
41
callbacks .append (layer_callback )
42
42
43
- def backprop (dY : InT ) -> InT :
43
+ def backprop (dY : InT ) -> OutT :
44
44
dX = first_callback (dY )
45
45
for callback in callbacks :
46
46
dX += callback (dY )
@@ -50,14 +50,20 @@ def backprop(dY: InT) -> InT:
50
50
51
51
52
52
def init (
53
- model : Model [InT , InT ], X : Optional [InT ] = None , Y : Optional [InT ] = None
54
- ) -> Model [InT , InT ]:
53
+ model : Model [InT , OutT ], X : Optional [InT ] = None , Y : Optional [OutT ] = None
54
+ ) -> Model [InT , OutT ]:
55
55
if X is not None :
56
56
if model .has_dim ("nI" ) is not False :
57
57
model .set_dim ("nI" , get_width (X ))
58
58
for layer in model .layers :
59
59
if layer .has_dim ("nI" ) is not False :
60
60
layer .set_dim ("nI" , get_width (X ))
61
+ if Y is not None :
62
+ if model .has_dim ("nO" ) is not False :
63
+ model .set_dim ("nO" , get_width (Y ))
64
+ for layer in model .layers :
65
+ if layer .has_dim ("nO" ) is not False :
66
+ layer .set_dim ("nO" , get_width (Y ))
61
67
for layer in model .layers :
62
68
layer .initialize (X = X , Y = Y )
63
69
model .set_dim ("nO" , model .layers [0 ].get_dim ("nO" ))
0 commit comments