Commit: push to device
ArthurZucker committed Dec 23, 2024
1 parent 05260a1 commit 1c0f29e
Showing 1 changed file with 4 additions and 1 deletion.
src/transformers/loss/loss_utils.py (4 additions, 1 deletion)
@@ -34,6 +34,7 @@ def ForCausalLMLoss(
 ):
     # Upcast to float if we need to compute the loss to avoid potential precision issues
     logits = logits.float()
+    labels = labels.to(logits.device)
     # Shift so that tokens < n predict n
     shift_logits = logits[..., :-1, :].contiguous()
     shift_labels = labels[..., 1:].contiguous()
@@ -52,6 +53,7 @@ def ForMaskedLMLoss(
 ):
     # Upcast to float if we need to compute the loss to avoid potential precision issues
     logits = logits.float()
+    labels = labels.to(logits.device)
 
     # Flatten the tokens
     logits = logits.view(-1, vocab_size)
@@ -73,6 +75,7 @@ def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs):
         else:
             config.problem_type = "multi_label_classification"
 
+    labels = labels.to(pooled_logits.device)
     if config.problem_type == "regression":
         loss_fct = MSELoss()
         if num_labels == 1:
@@ -109,7 +112,7 @@ def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_posi
 def ForTokenClassification(logits, labels, config, **kwargs):
     # Upcast to float if we need to compute the loss to avoid potential precision issues
     logits = logits.view(-1, config.num_labels)
-    labels = labels.view(-1)
+    labels = labels.view(-1).to(logits.device)
     logits = logits.float()
     # Flatten the tokens
     return fixed_cross_entropy(logits, labels, **kwargs)
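
All four hunks apply the same pattern: move labels onto the device that already holds the logits (or pooled logits) before computing the loss, so that models sharded across GPUs (for example when loaded with device_map="auto") do not raise a device-mismatch RuntimeError. The sketch below is a simplified, illustrative stand-in for the causal-LM case, not the repository's code; the function name is hypothetical and the plain F.cross_entropy call replaces the library's fixed_cross_entropy helper.

import torch
import torch.nn.functional as F

def causal_lm_loss_sketch(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Illustrative sketch of the pattern in this commit, not the library function.
    logits = logits.float()                          # upcast to avoid precision issues
    labels = labels.to(logits.device)                # the fix: align label device with logits
    shift_logits = logits[..., :-1, :].contiguous()  # tokens < n predict token n
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,                           # ignore padding positions
    )

Moving the labels is the cheap direction of the copy: labels are a small integer tensor, whereas the logits can be of size batch x sequence x vocab and would be costly to transfer across devices.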
