Skip to content

Commit

Permalink
Remove __len__ from CombinedStreamingDataset (#19321)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Jan 24, 2024
1 parent b446b08 commit 71bfdc3
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions src/lightning/data/streaming/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@ class CombinedStreamingDataset(IterableDataset):
"""The `CombinedStreamingDataset` enables streaming data from multiple StreamingDataset instances with the sampling ratio of
your choice.
Additionally, the `CombinedStreamingDataset` keeps track of the number of
samples fetched to enable resumability of the datasets.
Additionally, the `CombinedStreamingDataset` keeps track of the number of samples fetched to enable resumability
of the datasets.
Note that due to the random sampling, the number of samples returned from the iterator is variable and a function
of the given seed. The combined dataset will raise a StopIteration as soon as any of the datasets is exhausted.
"""

Expand Down Expand Up @@ -71,10 +74,6 @@ def _set_use_streaming_dataloader(self, use_streaming_dataloader: bool) -> None:
# Used to prevent returning num_samples_yielded when using PyTorch DataLoader
self._use_streaming_dataloader = use_streaming_dataloader

def __len__(self) -> int:
assert self._weights
return int(min([1 / w * len(d) for w, d in zip(self._weights, self._datasets) if w > 0]))

def __iter__(self) -> Iterator[Any]:
assert self._weights

Expand Down

0 comments on commit 71bfdc3

Please sign in to comment.