diff --git a/train/train_CSL_graph_classification.py b/train/train_CSL_graph_classification.py index 6b0952ebc..38c572f5f 100755 --- a/train/train_CSL_graph_classification.py +++ b/train/train_CSL_graph_classification.py @@ -16,108 +16,120 @@ def train_epoch_sparse(model, optimizer, device, data_loader, epoch): epoch_loss = 0 epoch_train_acc = 0 nb_data = 0 - gpu_mem = 0 + + optimizer.zero_grad() + for iter, (batch_graphs, batch_labels) in enumerate(data_loader): - batch_x = batch_graphs.ndata['feat'].to(device) # num x feat + batch_x = batch_graphs.ndata['feat'].to(device) batch_e = batch_graphs.edata['feat'].to(device) batch_graphs = batch_graphs.to(device) batch_labels = batch_labels.to(device) - optimizer.zero_grad() + try: batch_pos_enc = batch_graphs.ndata['pos_enc'].to(device) sign_flip = torch.rand(batch_pos_enc.size(1)).to(device) - sign_flip[sign_flip>=0.5] = 1.0; sign_flip[sign_flip<0.5] = -1.0 + sign_flip[sign_flip >= 0.5] = 1.0 + sign_flip[sign_flip < 0.5] = -1.0 batch_pos_enc = batch_pos_enc * sign_flip.unsqueeze(0) batch_scores = model.forward(batch_graphs, batch_x, batch_e, batch_pos_enc) except: batch_scores = model.forward(batch_graphs, batch_x, batch_e) - loss = model.loss(batch_scores, batch_labels) + + loss = model.loss(batch_scores, batch_labels) loss.backward() - optimizer.step() - epoch_loss += loss.detach().item() + + epoch_loss += loss.item() epoch_train_acc += accuracy(batch_scores, batch_labels) nb_data += batch_labels.size(0) + + optimizer.step() + optimizer.zero_grad() + epoch_loss /= (iter + 1) epoch_train_acc /= nb_data - + return epoch_loss, epoch_train_acc, optimizer + def evaluate_network_sparse(model, device, data_loader, epoch): model.eval() epoch_test_loss = 0 epoch_test_acc = 0 nb_data = 0 + with torch.no_grad(): for iter, (batch_graphs, batch_labels) in enumerate(data_loader): batch_x = batch_graphs.ndata['feat'].to(device) batch_e = batch_graphs.edata['feat'].to(device) batch_graphs = batch_graphs.to(device) batch_labels = batch_labels.to(device) + try: batch_pos_enc = batch_graphs.ndata['pos_enc'].to(device) batch_scores = model.forward(batch_graphs, batch_x, batch_e, batch_pos_enc) except: batch_scores = model.forward(batch_graphs, batch_x, batch_e) - loss = model.loss(batch_scores, batch_labels) - epoch_test_loss += loss.detach().item() + + loss = model.loss(batch_scores, batch_labels) + epoch_test_loss += loss.item() epoch_test_acc += accuracy(batch_scores, batch_labels) nb_data += batch_labels.size(0) - epoch_test_loss /= (iter + 1) - epoch_test_acc /= nb_data - - return epoch_test_loss, epoch_test_acc + epoch_test_loss /= (iter + 1) + epoch_test_acc /= nb_data + return epoch_test_loss, epoch_test_acc -""" - For WL-GNNs -""" -def train_epoch_dense(model, optimizer, device, data_loader, epoch, batch_size): +def train_epoch_dense(model, optimizer, device, data_loader, epoch): model.train() epoch_loss = 0 epoch_train_acc = 0 nb_data = 0 - gpu_mem = 0 + optimizer.zero_grad() + for iter, (x_with_node_feat, labels) in enumerate(data_loader): x_with_node_feat = x_with_node_feat.to(device) labels = labels.to(device) - + scores = model.forward(x_with_node_feat) - loss = model.loss(scores, labels) + loss = model.loss(scores, labels) loss.backward() - - if not (iter%batch_size): - optimizer.step() - optimizer.zero_grad() - - epoch_loss += loss.detach().item() + + epoch_loss += loss.item() epoch_train_acc += accuracy(scores, labels) nb_data += labels.size(0) + + optimizer.step() + optimizer.zero_grad() + epoch_loss /= (iter + 1) epoch_train_acc /= nb_data - + return epoch_loss, epoch_train_acc, optimizer + def evaluate_network_dense(model, device, data_loader, epoch): model.eval() epoch_test_loss = 0 epoch_test_acc = 0 nb_data = 0 + with torch.no_grad(): for iter, (x_with_node_feat, labels) in enumerate(data_loader): x_with_node_feat = x_with_node_feat.to(device) labels = labels.to(device) - + scores = model.forward(x_with_node_feat) - loss = model.loss(scores, labels) - epoch_test_loss += loss.detach().item() + loss = model.loss(scores, labels) + epoch_test_loss += loss.item() epoch_test_acc += accuracy(scores, labels) nb_data += labels.size(0) - epoch_test_loss /= (iter + 1) - epoch_test_acc /= nb_data - + + epoch_test_loss /= (iter + 1) + epoch_test_acc /= nb_data + return epoch_test_loss, epoch_test_acc @@ -128,4 +140,5 @@ def check_patience(all_losses, best_loss, best_epoch, curr_loss, curr_epoch, cou best_epoch = curr_epoch else: counter += 1 + return best_loss, best_epoch, counter