diff --git a/model_train.py b/model_train.py index 6ba5249..33120d2 100644 --- a/model_train.py +++ b/model_train.py @@ -156,6 +156,7 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders, #model_init.to(device) + print ("Initializing an Adam optimizer") optimizer = optim.Adam(model_init.Tmodel.parameters(), lr = 0.003, weight_decay= 0.0001) @@ -334,4 +335,4 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders, torch.save(model_init.state_dict(), mypath + "/best_performing_model.pth") del model_init - del ref_model \ No newline at end of file + del ref_model