Skip to content

Commit 815a6c7

Browse files
authoredNov 24, 2023
Update models.py
1 parent cc181e3 commit 815a6c7

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed
 

‎models.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,10 @@ def __init__(self, num_layers, input_dim, hidden_dimension, num_classes, dropout
164164
for _ in range(num_layers - 2):
165165
self.convs.append(GCNConv(hidden_dimension, hidden_dimension, cached=False,
166166
normalize=True))
167-
self.norms.append(torch.nn.BatchNorm1d(hidden_dimension))
167+
if norm:
168+
self.norms.append(torch.nn.BatchNorm1d(hidden_dimension))
169+
else:
170+
self.norms.append(torch.nn.Identity())
168171

169172
self.convs.append(GCNConv(hidden_dimension, num_classes, cached=False, normalize=True))
170173

@@ -599,4 +602,4 @@ def forward(self, graph, feat):
599602
h = self.bias_last(h)
600603

601604
return h
602-
605+

0 commit comments

Comments
 (0)