Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions recipes/full_dpo_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from torchtune.utils import get_world_size_and_rank
from tqdm import tqdm

from datetime import timedelta

class FullDPORecipeDistributed(FTRecipeInterface):
"""
Expand Down Expand Up @@ -153,8 +154,9 @@ def __init__(self, cfg: DictConfig) -> None:
self.distributed_backend = training.get_distributed_backend(
cfg.device, offload_ops_to_cpu=True
)
init_process_group(self.distributed_backend)
self._checkpoint_client = CheckpointClient(cfg)
# Raise the process group timeout to 1 hour to avoid ProcessGroupNCCL watchdog "collective operation timeout" errors
init_process_group(backend=self.distributed_backend, timeout=timedelta(seconds=3600))
self._checkpoint_client = CheckpointClient(cfg)

self.world_size, self.rank = get_world_size_and_rank()
self._is_rank_zero = self.rank == 0
Expand Down
6 changes: 4 additions & 2 deletions recipes/full_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@

from tqdm import tqdm

from datetime import timedelta

class FullFinetuneRecipeDistributed(FTRecipeInterface):
"""
Expand Down Expand Up @@ -148,8 +149,9 @@ def __init__(self, cfg: DictConfig) -> None:
offload_ops_to_cpu=self.fsdp_cpu_offload
or self._enable_async_checkpointing,
)
init_process_group(self.distributed_backend)

# Raise the process group timeout to 1 hour to avoid ProcessGroupNCCL watchdog "collective operation timeout" errors
init_process_group(backend=self.distributed_backend, timeout=timedelta(seconds=3600))

# Initialize distributed variables
self.world_size, self.rank = utils.get_world_size_and_rank()
self._is_rank_zero = self.rank == 0
Expand Down
6 changes: 4 additions & 2 deletions recipes/knowledge_distillation_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@

from tqdm import tqdm

from datetime import timedelta


class KDRecipeDistributed(FTRecipeInterface):
"""
Expand Down Expand Up @@ -124,8 +126,8 @@ def __init__(self, cfg: DictConfig) -> None:
offload_ops_to_cpu=self.fsdp_cpu_offload
or self._enable_async_checkpointing,
)
init_process_group(self.distributed_backend)

# Raise the process group timeout to 1 hour to avoid ProcessGroupNCCL watchdog "collective operation timeout" errors
init_process_group(backend=self.distributed_backend, timeout=timedelta(seconds=3600))
self.world_size, self.rank = utils.get_world_size_and_rank()

self._is_rank_zero = self.rank == 0
Expand Down
4 changes: 3 additions & 1 deletion recipes/lora_dpo_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
)
from tqdm import tqdm

from datetime import timedelta

class LoRADPORecipeDistributed(FTRecipeInterface):
"""
Expand Down Expand Up @@ -143,7 +144,8 @@ def __init__(self, cfg: DictConfig) -> None:
cfg.device, offload_ops_to_cpu=True
)

init_process_group(self.distributed_backend)
# Raise the process group timeout to 1 hour to avoid ProcessGroupNCCL watchdog "collective operation timeout" errors
init_process_group(backend=self.distributed_backend, timeout=timedelta(seconds=3600))

self.world_size, self.rank = utils.get_world_size_and_rank()

Expand Down
7 changes: 5 additions & 2 deletions recipes/lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
)
from tqdm import tqdm

from datetime import timedelta


class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
"""
Expand Down Expand Up @@ -148,8 +150,9 @@ def __init__(self, cfg: DictConfig) -> None:
offload_ops_to_cpu=self.fsdp_cpu_offload
or self._enable_async_checkpointing,
)
init_process_group(self.distributed_backend)

# Raise the process group timeout to 1 hour to avoid ProcessGroupNCCL watchdog "collective operation timeout" errors
init_process_group(backend=self.distributed_backend, timeout=timedelta(seconds=3600))

self.world_size, self.rank = utils.get_world_size_and_rank()

self._is_rank_zero = self.rank == 0
Expand Down
5 changes: 4 additions & 1 deletion recipes/qat_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@

from tqdm import tqdm

from datetime import timedelta


class QATRecipeDistributed(FTRecipeInterface):
"""
Expand Down Expand Up @@ -157,7 +159,8 @@ def __init__(self, cfg: DictConfig) -> None:
offload_ops_to_cpu=self.fsdp_cpu_offload
or self._enable_async_checkpointing,
)
init_process_group(self.distributed_backend)
# Raise the process group timeout to 1 hour to avoid ProcessGroupNCCL watchdog "collective operation timeout" errors
init_process_group(backend=self.distributed_backend, timeout=timedelta(seconds=3600))

# Initialize distributed variables
self.world_size, self.rank = utils.get_world_size_and_rank()
Expand Down
5 changes: 4 additions & 1 deletion recipes/qat_lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
from torchtune.training.quantization import swap_lora_linear_with_qat
from tqdm import tqdm

from datetime import timedelta


class QATLoRAFinetuneRecipeDistributed(FTRecipeInterface):
"""
Expand Down Expand Up @@ -156,7 +158,8 @@ def __init__(self, cfg: DictConfig) -> None:
offload_ops_to_cpu=self.fsdp_cpu_offload
or self._enable_async_checkpointing,
)
init_process_group(self.distributed_backend)
# Raise the process group timeout to 1 hour to avoid ProcessGroupNCCL watchdog "collective operation timeout" errors
init_process_group(backend=self.distributed_backend, timeout=timedelta(seconds=3600))

self.world_size, self.rank = utils.get_world_size_and_rank()

Expand Down