    PrefetchPipelinedForward,
)
from torchrec.distributed.train_pipeline.tracing import PipelinedPostproc
+from torchrec.distributed.train_pipeline.types import PipelineState
from torchrec.distributed.train_pipeline.utils import (
    _override_input_dist_forwards,
    _pipeline_detach_model,

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.streamable import Pipelineable

+
logger: logging.Logger = logging.getLogger(__name__)

# This is required to support older torch package export for older models
@@ -104,6 +106,10 @@ class TrainPipeline(abc.ABC, Generic[In, Out]):
    def progress(self, dataloader_iter: Iterator[In]) -> Out:
        pass

+    def __init__(self) -> None:
+        # pipeline state, such as in forward or in backward; used in training recovery scenarios
+        self._state: PipelineState = PipelineState.IDLE
+
    def sync_embeddings(
        self,
        model: torch.nn.Module,
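The new _state field relies on PipelineState, imported above from torchrec.distributed.train_pipeline.types; its definition is not part of this diff. A minimal sketch consistent with the three members referenced in this change (IDLE, CALL_FWD, CALL_BWD) could look as follows; the actual base class and member values in torchrec may differ:

from enum import Enum, unique


@unique
class PipelineState(Enum):
    # Coarse phase marker for a train pipeline. Members are inferred from the
    # usages added in this diff; the authoritative definition lives in
    # torchrec.distributed.train_pipeline.types.
    IDLE = 0
    CALL_FWD = 1
    CALL_BWD = 2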
@@ -192,6 +198,7 @@ def __init__(
        self._cur_batch: Optional[In] = None
        self._connected = False
        self._data_iter_stopped = False
+        super().__init__()

    def _reset_data_iter(self) -> None:
        self._connected = False
@@ -311,6 +318,7 @@ def __init__(
        self._cur_batch: Optional[In] = None

    def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        self._state = PipelineState.IDLE
        if self._iter == 0:
            # Turn on sync collectives for PT2 pipeline.
            # To have similar logic between compiled/graph_break ranks.
@@ -335,6 +343,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
        self._optimizer.zero_grad()

        with record_function("## forward ##"):
+            self._state = PipelineState.CALL_FWD
            if self._iter == cc.compile_on_iter:
                logger.info("Compiling model...")
                if self._pre_compile_fn:
@@ -362,6 +371,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:

        if self._model.training:
            with record_function("## backward ##"):
+                self._state = PipelineState.CALL_BWD
                torch.sum(losses).backward()

        with record_function("## optimizer ##"):
@@ -478,11 +488,13 @@ def __init__(
        self._dmp_collection_sync_interval_batches = (
            dmp_collection_sync_interval_batches
        )
+
        if self._dmp_collection_sync_interval_batches is not None:
            logger.info(
                f"{self.__class__.__name__}: [Sparse 2D] DMP collection will sync every "
                f"{self._dmp_collection_sync_interval_batches} batches"
            )
+        super().__init__()

        # DEPRECATED FIELDS
        self._batch_i: Optional[In] = None
@@ -634,6 +646,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
        batches[2]: i+2 batch, for copy_batch_to_gpu (expecting non-exhausted dataloader iter)
        """

+        self._state = PipelineState.IDLE
        # attach the model just in case the user forgets to call it, especially when the user
        # pauses the pipeline.progress and detach the model for other purpose.
        if not self._model_attached:
@@ -667,6 +680,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:

        # forward
        with record_function("## forward ##"):
+            self._state = PipelineState.CALL_FWD
            losses, output = self._model_fwd(self.batches[0])

        if self._enqueue_batch_after_forward:
@@ -681,6 +695,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:

        if self._model.training:
            # backward
+            self._state = PipelineState.CALL_BWD
            self._backward(losses)

        self.sync_embeddings(
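As the added comment notes, the marker is intended for training-recovery scenarios: each progress() call resets _state to IDLE and moves it to CALL_FWD and CALL_BWD around the forward and backward passes. A hypothetical failure handler, not part of torchrec and shown only to illustrate how the marker might be consumed (the function name and recovery policy are assumptions), could read it to decide how to resume:

from torchrec.distributed.train_pipeline.types import PipelineState


def decide_recovery_action(pipeline) -> str:
    # Illustrative only: reads the phase marker maintained by progress().
    state = pipeline._state
    if state in (PipelineState.IDLE, PipelineState.CALL_FWD):
        # Failure before or during the forward pass: no optimizer step has
        # been applied for this batch, so the batch can simply be retried.
        return "retry_batch"
    # PipelineState.CALL_BWD: gradients (and, with fused sparse optimizers,
    # possibly embedding weights) may be partially updated, so fall back to
    # the last checkpoint instead.
    return "restore_from_checkpoint"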