diff --git a/evals/video_classification_frozen/eval.py b/evals/video_classification_frozen/eval.py index f81f526d..44e38944 100644 --- a/evals/video_classification_frozen/eval.py +++ b/evals/video_classification_frozen/eval.py @@ -472,6 +472,7 @@ def make_dataloader( data_loader, _ = init_data( data=dataset_type, root_path=root_path, + training=training, transform=transform, batch_size=batch_size, world_size=world_size, diff --git a/src/datasets/data_manager.py b/src/datasets/data_manager.py index cdb7ade4..1080c1f6 100644 --- a/src/datasets/data_manager.py +++ b/src/datasets/data_manager.py @@ -63,6 +63,7 @@ def init_data( persistent_workers=persistent_workers, copy_data=copy_data, drop_last=drop_last, + shuffle=training, subset_file=subset_file) elif data.lower() == 'videodataset': @@ -86,6 +87,7 @@ def init_data( world_size=world_size, rank=rank, drop_last=drop_last, + shuffle=training, log_dir=log_dir) return (data_loader, dist_sampler) diff --git a/src/datasets/image_dataset.py b/src/datasets/image_dataset.py index 84e9b082..c45029a8 100644 --- a/src/datasets/image_dataset.py +++ b/src/datasets/image_dataset.py @@ -53,6 +53,7 @@ def make_imagedataset( copy_data=False, drop_last=True, persistent_workers=False, + shuffle=True, subset_file=None ): dataset = ImageFolder( @@ -64,7 +65,8 @@ def make_imagedataset( dist_sampler = torch.utils.data.distributed.DistributedSampler( dataset=dataset, num_replicas=world_size, - rank=rank) + rank=rank, + shuffle=shuffle) data_loader = torch.utils.data.DataLoader( dataset, collate_fn=collator, diff --git a/src/datasets/video_dataset.py b/src/datasets/video_dataset.py index b05cc701..130e8526 100644 --- a/src/datasets/video_dataset.py +++ b/src/datasets/video_dataset.py @@ -44,6 +44,7 @@ def make_videodataset( num_workers=10, pin_mem=True, duration=None, + shuffle=True, log_dir=None, ): dataset = VideoDataset( @@ -66,13 +67,13 @@ def make_videodataset( dataset.sample_weights, num_replicas=world_size, rank=rank, - shuffle=True) + shuffle=shuffle) else: dist_sampler = torch.utils.data.distributed.DistributedSampler( dataset, num_replicas=world_size, rank=rank, - shuffle=True) + shuffle=shuffle) data_loader = torch.utils.data.DataLoader( dataset,