Skip to content

Commit 5ca2713

Browse files
Chetter2facebook-github-bot
authored andcommitted
Fix performance of WeightedRandomSampler (pytorch#10636)
Summary: Since pytorch#8958 was merged, the BatchSampler samples 0d tensors from WeightedRandomSampler instead of integers. It significantly reduces performance. This PR fix it the same way as pytorch#10361 fix DistributedSampler. Pull Request resolved: pytorch#10636 Differential Revision: D9423869 Pulled By: zou3519 fbshipit-source-id: f94da2d4cccf70e63beea6cfc3d1230b5610ae44
1 parent 0e30fa6 commit 5ca2713

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torch/utils/data/sampler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def __init__(self, weights, num_samples, replacement=True):
9595
self.replacement = replacement
9696

9797
def __iter__(self):
98-
return iter(torch.multinomial(self.weights, self.num_samples, self.replacement))
98+
return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
9999

100100
def __len__(self):
101101
return self.num_samples

0 commit comments

Comments
 (0)