@@ -476,6 +476,23 @@ def _try_put_index(self) -> None:
476476 super ()._try_put_index ()
477477
478478
479+ class StreamingDataLoaderCollateFn :
480+ def __init__ (self , collate_fn : Optional [Callable ] = None ) -> None :
481+ self .collate_fn = collate_fn or default_collate
482+
483+ def __call__ (self , items : List [Any ]) -> Any :
484+ if len (items ) > 0 and isinstance (items [0 ], dict ) and __NUM_SAMPLES_YIELDED_KEY__ in items [0 ]:
485+ batch = self .collate_fn ([item [__SAMPLES_KEY__ ] for item in items ])
486+ return {
487+ __SAMPLES_KEY__ : batch ,
488+ __NUM_SAMPLES_YIELDED_KEY__ : [
489+ torch .cumsum ([torch .tensor (item [__NUM_SAMPLES_YIELDED_KEY__ ]) for item in items ][- 1 ], dim = 0 )
490+ ],
491+ }
492+
493+ return self .collate_fn (items )
494+
495+
479496class StreamingDataLoader (DataLoader ):
480497 r"""The StreamingDataLoader combines a dataset and a sampler, and provides an iterable over the given dataset.
481498
@@ -541,6 +558,7 @@ def __init__(
541558 prefetch_factor : Optional [int ] = None ,
542559 shuffle : Optional [bool ] = None ,
543560 drop_last : Optional [bool ] = False ,
561+ collate_fn : Optional [Callable ] = None ,
544562 ** kwargs : Any ,
545563 ) -> None : # pyright: ignore
546564 if not isinstance (dataset , (StreamingDataset , CombinedStreamingDataset )):
@@ -563,6 +581,9 @@ def __init__(
563581 if profile_batches and num_workers == 0 :
564582 raise ValueError ("Profiling is supported only with num_workers >= 1." )
565583
584+ if collate_fn :
585+ collate_fn = StreamingDataLoaderCollateFn (collate_fn )
586+
566587 self .current_epoch = 0
567588 self .batch_size = batch_size
568589 self .num_workers = num_workers
@@ -581,6 +602,7 @@ def __init__(
581602 batch_size = batch_size ,
582603 num_workers = num_workers ,
583604 prefetch_factor = (10 if num_workers > 0 else None ) if prefetch_factor is None else prefetch_factor ,
605+ collate_fn = collate_fn ,
584606 ** kwargs ,
585607 ) # type: ignore
586608
0 commit comments