[aisingapore#61] Replaced deprecated functional.tanh with torch.tanh
ktyap committed Nov 4, 2022
1 parent 97bc12f commit 5da1a48
Showing 3 changed files with 4 additions and 4 deletions.
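For reference, torch.nn.functional.tanh emits a deprecation warning on recent PyTorch releases, and torch.tanh is its element-wise drop-in replacement. A minimal sketch illustrating the equivalence:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3)
old = F.tanh(x)      # still works, but warns that it is deprecated
new = torch.tanh(x)  # preferred call, numerically identical result
assert torch.equal(old, new)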
4 changes: 2 additions & 2 deletions sgnlp/models/drnn_roberta/model.py
@@ -104,7 +104,7 @@ def forward(self, M, x, mask=None):
M_ = M.transpose(0,1) # batch, seqlen, mem_dim
x_ = x.unsqueeze(1).expand(-1,M.size()[0],-1) # batch, seqlen, cand_dim
M_x_ = torch.cat([M_,x_],2) # batch, seqlen, mem_dim+cand_dim
- mx_a = F.tanh(self.transform(M_x_)) # batch, seqlen, alpha_dim
+ mx_a = torch.tanh(self.transform(M_x_)) # batch, seqlen, alpha_dim
alpha = F.softmax(self.vector_prod(mx_a),1).transpose(1,2) # batch, 1, seqlen

attn_pool = torch.bmm(alpha, M.transpose(0,1))[:,0,:] # batch, mem_dim
@@ -426,7 +426,7 @@ def forward(
att_features = torch.cat(att_features, dim=0)
hidden = F.relu(self.linear(att_features))
else:
- hidden = F.tanh(self.linear(features))
+ hidden = torch.tanh(self.linear(features))


log_prob = F.log_softmax(self.smax_fc(hidden), 2)
2 changes: 1 addition & 1 deletion sgnlp/models/drnn_roberta/modeling.py
@@ -260,7 +260,7 @@ def forward(
att_features = torch.cat(att_features, dim=0)
hidden = F.relu(self.linear(att_features))
else:
- hidden = F.tanh(self.linear(features))
+ hidden = torch.tanh(self.linear(features))


log_prob = F.log_softmax(self.smax_fc(hidden), 2)
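The hunks above in model.py and modeling.py touch the same classification head, where the non-attentive branch applies tanh before a log-softmax output layer. A self-contained sketch of that branch, with layer sizes (feat_dim, hidden_dim, n_classes) assumed purely for illustration:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ClassifierHeadSketch(nn.Module):
    # Sketch of the head touched by this commit; dimensions are assumptions.
    def __init__(self, feat_dim=200, hidden_dim=100, n_classes=7):
        super().__init__()
        self.linear = nn.Linear(feat_dim, hidden_dim)
        self.smax_fc = nn.Linear(hidden_dim, n_classes)

    def forward(self, features):
        # features: (seqlen, batch, feat_dim)
        hidden = torch.tanh(self.linear(features))      # torch.tanh replaces F.tanh
        return F.log_softmax(self.smax_fc(hidden), 2)   # seqlen, batch, n_classes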
2 changes: 1 addition & 1 deletion sgnlp/models/drnn_roberta/modules.py
@@ -91,7 +91,7 @@ def forward(self, M, x, mask=None):
M_ = M.transpose(0,1) # batch, seqlen, mem_dim
x_ = x.unsqueeze(1).expand(-1,M.size()[0],-1) # batch, seqlen, cand_dim
M_x_ = torch.cat([M_,x_],2) # batch, seqlen, mem_dim+cand_dim
- mx_a = F.tanh(self.transform(M_x_)) # batch, seqlen, alpha_dim
+ mx_a = torch.tanh(self.transform(M_x_)) # batch, seqlen, alpha_dim
alpha = F.softmax(self.vector_prod(mx_a),1).transpose(1,2) # batch, 1, seqlen

attn_pool = torch.bmm(alpha, M.transpose(0,1))[:,0,:] # batch, mem_dim
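model.py and modules.py change the same matching-attention forward pass. A self-contained sketch of that computation, with the transform and vector_prod layer shapes assumed from the inline dimension comments:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MatchingAttentionSketch(nn.Module):
    # Sketch of the attention pooling touched by this commit; layer shapes are assumptions.
    def __init__(self, mem_dim=100, cand_dim=100, alpha_dim=100):
        super().__init__()
        self.transform = nn.Linear(cand_dim + mem_dim, alpha_dim)
        self.vector_prod = nn.Linear(alpha_dim, 1, bias=False)

    def forward(self, M, x):
        # M: (seqlen, batch, mem_dim), x: (batch, cand_dim)
        M_ = M.transpose(0, 1)                               # batch, seqlen, mem_dim
        x_ = x.unsqueeze(1).expand(-1, M.size()[0], -1)      # batch, seqlen, cand_dim
        M_x_ = torch.cat([M_, x_], 2)                        # batch, seqlen, mem_dim+cand_dim
        mx_a = torch.tanh(self.transform(M_x_))              # torch.tanh replaces F.tanh
        alpha = F.softmax(self.vector_prod(mx_a), 1).transpose(1, 2)  # batch, 1, seqlen
        attn_pool = torch.bmm(alpha, M.transpose(0, 1))[:, 0, :]      # batch, mem_dim
        return attn_pool, alpha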
