We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 38ff276 + 322014c commit fa3e4f3Copy full SHA for fa3e4f3
algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py
@@ -80,9 +80,7 @@ def _build_input_queue(
80
weights = torch.as_tensor(
81
batch['weights'], dtype=torch.bool, device=DEVICE)
82
else:
83
- weights = torch.ones((batch['targets'].shape[-1],),
84
- dtype=torch.bool,
85
- device=DEVICE)
+ weights = torch.ones_like(targets, dtype=torch.bool, device=DEVICE)
86
# Send batch to other devices when using DDP.
87
if USE_PYTORCH_DDP:
88
dist.broadcast(inputs, src=0)
0 commit comments