Skip to content
Open
8 changes: 8 additions & 0 deletions src/fairchem/core/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ class DistributedInitMethod(StrEnum):
FILE = "file"


class StartMethod(StrEnum):
SPAWN = "spawn"
FORK = "fork"
FORKSERVER = "forkserver"


@dataclass
class SlurmConfig:
mem_gb: int = 80
Expand All @@ -98,6 +104,7 @@ class SlurmConfig:
class SchedulerConfig:
mode: SchedulerType = SchedulerType.LOCAL
distributed_init_method: DistributedInitMethod = DistributedInitMethod.TCP
start_method: StartMethod = StartMethod.SPAWN # this is only used when `mode=LOCAL`
ranks_per_node: int = 1
num_nodes: int = 1
num_array_jobs: int = 1
Expand Down Expand Up @@ -495,6 +502,7 @@ def main(
min_nodes=1,
max_nodes=1,
nproc_per_node=scheduler_cfg.ranks_per_node,
start_method=scheduler_cfg.start_method,
rdzv_backend="c10d",
max_restarts=0,
)
Expand Down
4 changes: 2 additions & 2 deletions src/fairchem/core/components/common/dataloader_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


def get_dataloader(
dataset: Dataset, batch_sampler_fn: callable, collate_fn: callable, num_workers
dataset: Dataset, batch_sampler_fn: callable, collate_fn: callable, num_workers, pin_memory=True,
) -> DataLoader:
if gp_utils.initialized():
num_replicas = gp_utils.get_dp_world_size()
Expand All @@ -35,7 +35,7 @@ def get_dataloader(
dataset=dataset,
collate_fn=collate_fn,
num_workers=num_workers,
pin_memory=True,
pin_memory=pin_memory,
batch_sampler=batch_sampler,
)
logging.info("get_dataloader::Done!")
Expand Down
Loading