
Commit 3795566

Ruilin Chen authored and facebook-github-bot committed
add in forward in training state (#3280)
Summary:
Pull Request resolved: #3280

As title: add state, with Enums, for failure-recovery cases to check the training state.

Reviewed By: malaybag, TroyGarden

Differential Revision: D80054126

fbshipit-source-id: 1e9927efc091b2a6d5917157cf82bed5758a3b22
1 parent fdd8534 · commit 3795566

File tree

2 files changed: +27, -0 lines


torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 15 additions & 0 deletions
@@ -53,6 +53,7 @@
     PrefetchPipelinedForward,
 )
 from torchrec.distributed.train_pipeline.tracing import PipelinedPostproc
+from torchrec.distributed.train_pipeline.types import PipelineState
 from torchrec.distributed.train_pipeline.utils import (
     _override_input_dist_forwards,
     _pipeline_detach_model,
@@ -72,6 +73,7 @@
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.streamable import Pipelineable
 
+
 logger: logging.Logger = logging.getLogger(__name__)
 
 # This is required to support older torch package export for older models
@@ -104,6 +106,10 @@ class TrainPipeline(abc.ABC, Generic[In, Out]):
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
         pass
 
+    def __init__(self) -> None:
+        # Pipeline state (e.g. in forward, in backward); used in training recovery scenarios.
+        self._state: PipelineState = PipelineState.IDLE
+
     def sync_embeddings(
         self,
         model: torch.nn.Module,
@@ -192,6 +198,7 @@ def __init__(
         self._cur_batch: Optional[In] = None
         self._connected = False
         self._data_iter_stopped = False
+        super().__init__()
 
     def _reset_data_iter(self) -> None:
         self._connected = False
@@ -311,6 +318,7 @@ def __init__(
         self._cur_batch: Optional[In] = None
 
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        self._state = PipelineState.IDLE
         if self._iter == 0:
             # Turn on sync collectives for PT2 pipeline.
             # To have similar logic between compiled/graph_break ranks.
@@ -335,6 +343,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
             self._optimizer.zero_grad()
 
         with record_function("## forward ##"):
+            self._state = PipelineState.CALL_FWD
             if self._iter == cc.compile_on_iter:
                 logger.info("Compiling model...")
                 if self._pre_compile_fn:
@@ -362,6 +371,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
 
         if self._model.training:
             with record_function("## backward ##"):
+                self._state = PipelineState.CALL_BWD
                 torch.sum(losses).backward()
 
         with record_function("## optimizer ##"):
@@ -478,11 +488,13 @@ def __init__(
         self._dmp_collection_sync_interval_batches = (
             dmp_collection_sync_interval_batches
         )
+
         if self._dmp_collection_sync_interval_batches is not None:
             logger.info(
                 f"{self.__class__.__name__}: [Sparse 2D] DMP collection will sync every "
                 f"{self._dmp_collection_sync_interval_batches} batches"
             )
+        super().__init__()
 
         # DEPRECATED FIELDS
         self._batch_i: Optional[In] = None
@@ -634,6 +646,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         batches[2]: i+2 batch, for copy_batch_to_gpu (expecting non-exhausted dataloader iter)
         """
 
+        self._state = PipelineState.IDLE
         # Attach the model just in case the user forgets to call it, especially when the
         # user pauses pipeline.progress and detaches the model for other purposes.
         if not self._model_attached:
@@ -667,6 +680,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
 
         # forward
         with record_function("## forward ##"):
+            self._state = PipelineState.CALL_FWD
             losses, output = self._model_fwd(self.batches[0])
 
         if self._enqueue_batch_after_forward:
@@ -681,6 +695,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
 
         if self._model.training:
             # backward
+            self._state = PipelineState.CALL_BWD
             self._backward(losses)
 
         self.sync_embeddings(
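
For orientation, here is a minimal, self-contained sketch of the pattern the diff above introduces: the pipeline records which phase of the training step it is in, so failure-recovery logic can inspect that state after a crash. `PipelineState` and its members mirror the new torchrec enum; `ToyPipeline` and `describe_failure` are hypothetical stand-ins for illustration, not torchrec APIs.

# Minimal sketch of the state-tracking pattern in this diff.
# PipelineState mirrors torchrec's new enum; the pipeline class and the
# recovery hook below are hypothetical, not torchrec APIs.
from enum import Enum, unique


@unique
class PipelineState(Enum):
    IDLE = 0
    CALL_FWD = 1
    CALL_BWD = 2


class ToyPipeline:
    def __init__(self) -> None:
        # Mirrors TrainPipeline.__init__: every pipeline starts idle.
        self._state: PipelineState = PipelineState.IDLE

    def progress(self, batch: float) -> float:
        # Each step resets to IDLE, then flips the state around the
        # forward and backward phases, as progress() does in the diff.
        self._state = PipelineState.IDLE
        self._state = PipelineState.CALL_FWD
        loss = batch * 2.0  # stand-in for the model forward
        self._state = PipelineState.CALL_BWD
        loss -= 1.0  # stand-in for the backward pass
        return loss


def describe_failure(pipeline: ToyPipeline) -> str:
    # Hypothetical recovery hook: the phase the pipeline was in when it
    # crashed tells the recovery logic how much of the step completed.
    if pipeline._state is PipelineState.CALL_BWD:
        return "died in backward: gradients/optimizer may be mid-update"
    if pipeline._state is PipelineState.CALL_FWD:
        return "died in forward: no gradients were applied this step"
    return "died between steps: safe to resume from the last checkpoint"


pipe = ToyPipeline()
pipe.progress(3.0)
print(describe_failure(pipe))  # -> "died in backward: ..."

Note that the real pipelines only set `self._state`; the diff leaves how recovery code consumes it to the caller.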

torchrec/distributed/train_pipeline/types.py

Lines changed: 12 additions & 0 deletions
@@ -8,6 +8,7 @@
 # pyre-strict
 import abc
 from dataclasses import dataclass
+from enum import Enum, unique
 from typing import Any, Dict, List, Tuple
 
 
@@ -84,3 +85,14 @@ def build_args_kwargs(
             key: arg.process_steps(initial_input) for key, arg in self.kwargs.items()
         }
         return args, kwargs
+
+
+@unique
+class PipelineState(Enum):
+    """
+    Pipeline state for the train pipeline.
+    """
+
+    IDLE = 0
+    CALL_FWD = 1
+    CALL_BWD = 2
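
A brief note on the enum itself: `@unique` makes duplicate member values a class-definition-time error, so the three states can never alias one another. The snippet below illustrates the resulting semantics; it is standard Python `enum` behavior, not additional torchrec code.

from enum import Enum, unique


@unique
class PipelineState(Enum):
    """Pipeline state for the train pipeline."""

    IDLE = 0
    CALL_FWD = 1
    CALL_BWD = 2


# Members are singletons, so identity checks are the idiomatic comparison.
state = PipelineState.CALL_FWD
assert state is PipelineState.CALL_FWD
assert state is not PipelineState.CALL_BWD

# name/value are handy when serializing the state for recovery logs.
print(state.name, state.value)  # prints: CALL_FWD 1

# If CALL_BWD were accidentally also assigned the value 1, @unique would
# raise "ValueError: duplicate values found" when the class is defined.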
