diff --git a/baler/modules/models.py b/baler/modules/models.py index 62075b92..c1d75d53 100644 --- a/baler/modules/models.py +++ b/baler/modules/models.py @@ -355,7 +355,7 @@ def __init__(self, n_features, z_dim, *args, **kwargs): nn.ReLU(), # nn.BatchNorm1d(self.q_z_output_dim), nn.Linear(self.q_z_mid_dim, self.q_z_output_dim), - nn.ReLU() + nn.ReLU(), # nn.BatchNorm1d(42720) ) # Conv Layers @@ -613,7 +613,7 @@ def __init__(self, n_features, z_dim, *args, **kwargs): nn.ReLU(), # nn.BatchNorm1d(self.q_z_output_dim), nn.Linear(self.q_z_mid_dim, self.q_z_output_dim), - nn.ReLU() + nn.ReLU(), # nn.BatchNorm1d(42720) ) # Conv Layers