diff --git a/allrank/models/transformer.py b/allrank/models/transformer.py
index c09aaa8..dd1b964 100644
--- a/allrank/models/transformer.py
+++ b/allrank/models/transformer.py
@@ -102,8 +102,8 @@ def forward(self, x, sublayer):
         :param sublayer: layer through which to pass the input prior to applying the sum
         :return: output of shape [batch_size, slate_length, output_dim]
         """
-        return x + self.dropout(
-            sublayer(self.norm(x)))
+        return self.norm(x + self.dropout(
+            sublayer(x)))
 
 
 class EncoderLayer(nn.Module):