Skip to content

Commit cc181e3

Browse files
authored
Update train_utils.py
1 parent bb83bd1 commit cc181e3

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

train_utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def get_optimizer(args, model):
5050

5151

5252
def train(model, data, optimizer, loss_fn, train_mask, val_mask):
53+
model.train()
5354
optimizer.zero_grad()
5455
preds = model(data)
5556
if len(data.y.shape) != 1:
@@ -437,4 +438,4 @@ def annotator(pred_texts, label_names):
437438
conf.append(0)
438439

439440
anno = torch.LongTensor(anno)
440-
return anno, conf
441+
return anno, conf

0 commit comments

Comments
 (0)