Skip to content

Commit

Permalink
Modified some confusing comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
somepago authored Dec 6, 2021
1 parent 2f697e0 commit e288e84
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,13 @@
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
optimizer.zero_grad()
# x_categ is the the categorical data, with y appended as last feature. x_cont has continuous data. cat_mask is an array of ones same shape as x_categ except for last column(corresponding to y's) set to 0s. con_mask is an array of ones same shape as x_cont.
# x_categ is the the categorical data, x_cont has continuous data, y_gts has ground truth ys. cat_mask is an array of ones same shape as x_categ and an additional column(corresponding to CLS token) set to 0s. con_mask is an array of ones same shape as x_cont.
x_categ, x_cont, y_gts, cat_mask, con_mask = data[0].to(device), data[1].to(device),data[2].to(device),data[3].to(device),data[4].to(device)

# We are converting the data to embeddings in the next step
_ , x_categ_enc, x_cont_enc = embed_data_mask(x_categ, x_cont, cat_mask, con_mask,model,vision_dset)
reps = model.transformer(x_categ_enc, x_cont_enc)
# select only the representations corresponding to y and apply mlp on it in the next step to get the predictions.
# select only the representations corresponding to CLS token and apply mlp on it in the next step to get the predictions.
y_reps = reps[:,0,:]

y_outs = model.mlpfory(y_reps)
Expand Down

0 comments on commit e288e84

Please sign in to comment.