Skip to content

Commit 01f78da

Browse files
committed
Merge branch 'master' of https://github.com/explosion/thinc
2 parents 63fff00 + 88c8808 commit 01f78da

File tree

1 file changed

+12
-6
lines changed

thinc/layers/add.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from typing import Tuple, Callable, Optional, TypeVar, Dict
1+
from typing import Any, Tuple, Callable, Optional, TypeVar, Dict
22

33
from ..model import Model
44
from ..config import registry
55
from ..types import ArrayXd, XY_XY_OutT
66
from ..util import get_width
77

88

9-
InT = TypeVar("InT", bound=ArrayXd)
9+
InT = TypeVar("InT", bound=Any)
1010
OutT = TypeVar("OutT", bound=ArrayXd)
1111

1212

@@ -30,7 +30,7 @@ def add(
3030
return Model("add", forward, init=init, dims=dims, layers=layers)
3131

3232

33-
def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callable]:
33+
def forward(model: Model[InT, OutT], X: InT, is_train: bool) -> Tuple[OutT, Callable]:
3434
if not model.layers:
3535
return X, lambda dY: dY
3636
Y, first_callback = model.layers[0](X, is_train=is_train)
@@ -40,7 +40,7 @@ def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callab
4040
Y += layer_Y
4141
callbacks.append(layer_callback)
4242

43-
def backprop(dY: InT) -> InT:
43+
def backprop(dY: InT) -> OutT:
4444
dX = first_callback(dY)
4545
for callback in callbacks:
4646
dX += callback(dY)
@@ -50,14 +50,20 @@ def backprop(dY: InT) -> InT:
5050

5151

5252
def init(
53-
model: Model[InT, InT], X: Optional[InT] = None, Y: Optional[InT] = None
54-
) -> Model[InT, InT]:
53+
model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
54+
) -> Model[InT, OutT]:
5555
if X is not None:
5656
if model.has_dim("nI") is not False:
5757
model.set_dim("nI", get_width(X))
5858
for layer in model.layers:
5959
if layer.has_dim("nI") is not False:
6060
layer.set_dim("nI", get_width(X))
61+
if Y is not None:
62+
if model.has_dim("nO") is not False:
63+
model.set_dim("nO", get_width(Y))
64+
for layer in model.layers:
65+
if layer.has_dim("nO") is not False:
66+
layer.set_dim("nO", get_width(Y))
6167
for layer in model.layers:
6268
layer.initialize(X=X, Y=Y)
6369
model.set_dim("nO", model.layers[0].get_dim("nO"))

0 commit comments

Comments (0)