diff --git a/src/fairchem/core/_cli.py b/src/fairchem/core/_cli.py
index 68cdb4103d..109186f39b 100644
--- a/src/fairchem/core/_cli.py
+++ b/src/fairchem/core/_cli.py
@@ -79,6 +79,12 @@ class DistributedInitMethod(StrEnum):
     FILE = "file"
 
 
+class StartMethod(StrEnum):
+    SPAWN = "spawn"
+    FORK = "fork"
+    FORKSERVER = "forkserver"
+
+
 @dataclass
 class SlurmConfig:
     mem_gb: int = 80
@@ -98,6 +104,7 @@
 class SchedulerConfig:
     mode: SchedulerType = SchedulerType.LOCAL
     distributed_init_method: DistributedInitMethod = DistributedInitMethod.TCP
+    start_method: StartMethod = StartMethod.SPAWN  # this is only used when `mode=LOCAL`
     ranks_per_node: int = 1
     num_nodes: int = 1
     num_array_jobs: int = 1
@@ -495,6 +502,7 @@ def main(
             min_nodes=1,
             max_nodes=1,
             nproc_per_node=scheduler_cfg.ranks_per_node,
+            start_method=scheduler_cfg.start_method,
             rdzv_backend="c10d",
             max_restarts=0,
         )
diff --git a/src/fairchem/core/components/common/dataloader_builder.py b/src/fairchem/core/components/common/dataloader_builder.py
index 0368cf2300..67f9656ea5 100644
--- a/src/fairchem/core/components/common/dataloader_builder.py
+++ b/src/fairchem/core/components/common/dataloader_builder.py
@@ -15,7 +15,7 @@
 
 
 def get_dataloader(
-    dataset: Dataset, batch_sampler_fn: callable, collate_fn: callable, num_workers
+    dataset: Dataset, batch_sampler_fn: callable, collate_fn: callable, num_workers, pin_memory=True,
 ) -> DataLoader:
     if gp_utils.initialized():
         num_replicas = gp_utils.get_dp_world_size()
@@ -35,7 +35,7 @@ def get_dataloader(
         dataset=dataset,
         collate_fn=collate_fn,
         num_workers=num_workers,
-        pin_memory=True,
+        pin_memory=pin_memory,
         batch_sampler=batch_sampler,
     )
     logging.info("get_dataloader::Done!")