@@ -23,6 +23,7 @@
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.pipelining import SplitPoint, pipeline
 from torch.export import export
+from torch.utils.tensorboard import SummaryWriter
 from torchdata.stateful_dataloader import StatefulDataLoader
 
 from torchft import (
@@ -41,7 +42,11 @@
 @record
 def main() -> None:
     REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
-    NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
+    RUN = int(os.environ.get("RUN", 0))
+
+    output_folder = f"output/replica-{REPLICA_GROUP_ID}"
+
+    writer = SummaryWriter(f"{output_folder}/tensorboard", max_queue=1000)
 
     def load_state_dict(state_dict):
         m.load_state_dict(state_dict["model"])
@@ -171,12 +176,12 @@ def forward(self, x):
     num_params = sum(p.numel() for p in m.parameters())
     print(f"Total number of parameters: {num_params}")
 
-    sort_by_keyword = "self_" + device + "_time_total"
-
     def trace_handler(p):
-        p.export_chrome_trace(
-            f"/home/tushar00jain/trace_{p.step_num}_{REPLICA_GROUP_ID}.json"
-        )
+        dir = f"{output_folder}/profiles"
+        if not os.path.exists(dir):
+            os.makedirs(dir, exist_ok=True)
+
+        p.export_chrome_trace(f"{dir}/step-{p.step_num}.json")
 
     # You can use an epoch based training but with faults it's easier to use step
     # based training.
@@ -188,6 +193,7 @@ def trace_handler(p):
     )
 
     prof.start()
+    tensorboard_key_prefix = f"Run:{RUN}"
    with DiLoCo(
        manager,
        module_partitions if USE_STREAMING else [m],
@@ -210,16 +216,27 @@ def trace_handler(p):
             out = m(inputs)
             loss = criterion(out, labels)
 
+            writer.add_scalar(f"{tensorboard_key_prefix}/loss", loss, i)
+
             loss.backward()
 
             inner_optimizer.step()
 
+            writer.add_scalar(
+                f"{tensorboard_key_prefix}/num_participants",
+                manager.num_participants(),
+                i,
+            )
+            writer.add_scalar(
+                f"{tensorboard_key_prefix}/current_step", manager.current_step(), i
+            )
             if manager.current_step() % 100 == 0:
                 print(f"[{manager.current_step()}] loss = {loss.item()}")
 
             if manager.current_step() >= 15:
                 # complete training
                 prof.stop()
+                writer.flush()
                 exit()
 
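For reference, a minimal standalone sketch of the TensorBoard logging pattern this commit introduces: one SummaryWriter per replica group writing under `output/replica-<id>/tensorboard`, scalars tagged with a `Run:<n>` prefix, and a `flush()` before the process exits. The loop and its values below are placeholders; the actual change logs `loss`, `manager.num_participants()`, and `manager.current_step()` inside the DiLoCo training loop.

```python
import os

from torch.utils.tensorboard import SummaryWriter

# Per-replica output folder and run id, mirroring the commit's environment variables.
REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
RUN = int(os.environ.get("RUN", 0))

output_folder = f"output/replica-{REPLICA_GROUP_ID}"
writer = SummaryWriter(f"{output_folder}/tensorboard", max_queue=1000)
tensorboard_key_prefix = f"Run:{RUN}"

for i in range(10):
    # Stand-in for the real loss / manager metrics logged in the training loop.
    fake_loss = 1.0 / (i + 1)
    writer.add_scalar(f"{tensorboard_key_prefix}/loss", fake_loss, i)

# Flush buffered events before exiting, as the commit does before exit().
writer.flush()
```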
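Similarly, a hedged sketch of the per-step chrome-trace export that the new `trace_handler` performs. The profiler schedule and workload here are made up for illustration; the commit wires its handler into an existing profiler and brackets the training loop with `prof.start()` / `prof.stop()`.

```python
import os

import torch
from torch.profiler import ProfilerActivity, profile, schedule

# Hypothetical per-replica folder, matching the commit's output layout.
output_folder = "output/replica-0"


def trace_handler(p):
    # Write one chrome trace per profiled step into <output_folder>/profiles.
    trace_dir = f"{output_folder}/profiles"
    os.makedirs(trace_dir, exist_ok=True)
    p.export_chrome_trace(f"{trace_dir}/step-{p.step_num}.json")


with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=2),
    on_trace_ready=trace_handler,
) as prof:
    for _ in range(8):
        torch.randn(128, 128) @ torch.randn(128, 128)  # stand-in workload
        prof.step()
```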