Skip to content

Commit fa3e4f3

Browse files
Merge pull request #585 from mlcommons/dev
dev -> main
2 parents 38ff276 + 322014c commit fa3e4f3

File tree

1 file changed

+1
-3
lines changed
  • algorithmic_efficiency/workloads/mnist/mnist_pytorch

1 file changed

+1
-3
lines changed

algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,7 @@ def _build_input_queue(
8080
weights = torch.as_tensor(
8181
batch['weights'], dtype=torch.bool, device=DEVICE)
8282
else:
83-
weights = torch.ones((batch['targets'].shape[-1],),
84-
dtype=torch.bool,
85-
device=DEVICE)
83+
weights = torch.ones_like(targets, dtype=torch.bool, device=DEVICE)
8684
# Send batch to other devices when using DDP.
8785
if USE_PYTORCH_DDP:
8886
dist.broadcast(inputs, src=0)

0 commit comments

Comments
 (0)