Commit d8e3429
add tensorboard to training script

Summary:
- add tensorboard integration and separate the metrics by run id and replica id
- have an output folder per replica id

Test Plan: screenshot of the resulting TensorBoard dashboard: https://github.com/user-attachments/assets/44d60f20-c232-4c26-92a6-4749b846513b

1 parent 7898bfd
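With this scheme, every replica group writes its TensorBoard events under its own directory, and the `Run:{RUN}` tag prefix keeps metrics from different runs apart within the same event directory, so a single TensorBoard instance pointed at `output/` can compare everything side by side. A minimal sketch of the separation (the environment variables and paths follow the diff below; the logged value and step are placeholders):

```python
import os
from torch.utils.tensorboard import SummaryWriter

# One writer per replica group; directory and tag scheme as in the diff below.
replica_group_id = int(os.environ.get("REPLICA_GROUP_ID", 0))
run = int(os.environ.get("RUN", 0))

writer = SummaryWriter(f"output/replica-{replica_group_id}/tensorboard")
writer.add_scalar(f"Run:{run}/loss", 0.5, 1)  # placeholder value and step
writer.close()

# Compare all replicas in one dashboard with: tensorboard --logdir output
```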

File tree: .gitignore, train_diloco.py

2 files changed: 25 additions, 6 deletions

.gitignore (2 additions, 0 deletions)

```diff
@@ -30,3 +30,5 @@ dist/
 
 # Torch
 cifar/
+
+output/
```

train_diloco.py (23 additions, 6 deletions)
```diff
@@ -23,6 +23,7 @@
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.pipelining import SplitPoint, pipeline
 from torch.export import export
+from torch.utils.tensorboard import SummaryWriter
 from torchdata.stateful_dataloader import StatefulDataLoader
 
 from torchft import (
```
```diff
@@ -41,7 +42,11 @@
 @record
 def main() -> None:
     REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
-    NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
+    RUN = int(os.environ.get("RUN", 0))
+
+    output_folder = f"output/replica-{REPLICA_GROUP_ID}"
+
+    writer = SummaryWriter(f"{output_folder}/tensorboard", max_queue=1000)
 
     def load_state_dict(state_dict):
         m.load_state_dict(state_dict["model"])
```
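This hunk swaps the `NUM_REPLICA_GROUPS` constant for a `RUN` id and gives each replica group its own output folder and `SummaryWriter` (`max_queue=1000` lets up to 1000 pending events buffer in memory before an add call forces a write to disk). Roughly, the resulting on-disk layout looks like this (the `events.out.tfevents.*` name is TensorBoard's standard event-file pattern, not something set by this script):

```
output/
└── replica-0/
    ├── tensorboard/
    │   └── events.out.tfevents.*   # scalar events written by SummaryWriter
    └── profiles/
        └── step-<N>.json           # Chrome traces, see the trace_handler hunk
```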
```diff
@@ -171,12 +176,12 @@ def forward(self, x):
     num_params = sum(p.numel() for p in m.parameters())
     print(f"Total number of parameters: {num_params}")
 
-    sort_by_keyword = "self_" + device + "_time_total"
-
     def trace_handler(p):
-        p.export_chrome_trace(
-            f"/home/tushar00jain/trace_{p.step_num}_{REPLICA_GROUP_ID}.json"
-        )
+        dir = f"{output_folder}/profiles"
+        if not os.path.exists(dir):
+            os.makedirs(dir, exist_ok=True)
+
+        p.export_chrome_trace(f"{dir}/step-{p.step_num}.json")
 
     # You can use an epoch based training but with faults it's easier to use step
     # based training.
```
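The hunk rewrites only the handler body, sending traces to the per-replica `profiles/` folder instead of a hard-coded home directory; the profiler that invokes the handler is configured elsewhere in the script. As a rough sketch of how such a handler is conventionally attached, assuming a standard `torch.profiler` setup (the schedule values here are illustrative, not taken from this file):

```python
import torch.profiler

# torch.profiler invokes on_trace_ready (here, trace_handler) at the end of
# each profiling cycle, passing the profiler object whose step_num names the
# exported trace file.
prof = torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
    on_trace_ready=trace_handler,
)
```

The exported `step-<N>.json` files can then be opened in chrome://tracing or Perfetto.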
```diff
@@ -188,6 +193,7 @@ def trace_handler(p):
     )
 
     prof.start()
+    tensorboard_key_prefix = f"Run:{RUN}"
     with DiLoCo(
         manager,
         module_partitions if USE_STREAMING else [m],
@@ -210,16 +216,27 @@
             out = m(inputs)
             loss = criterion(out, labels)
 
+            writer.add_scalar(f"{tensorboard_key_prefix}/loss", loss, i)
+
             loss.backward()
 
             inner_optimizer.step()
 
+            writer.add_scalar(
+                f"{tensorboard_key_prefix}/num_participants",
+                manager.num_participants(),
+                i,
+            )
+            writer.add_scalar(
+                f"{tensorboard_key_prefix}/current_step", manager.current_step(), i
+            )
             if manager.current_step() % 100 == 0:
                 print(f"[{manager.current_step()}] loss = {loss.item()}")
 
             if manager.current_step() >= 15:
                 # complete training
                 prof.stop()
+                writer.flush()
                 exit()
```
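To sanity-check what a replica logged, the event files can be read back with TensorBoard's `EventAccumulator`; a minimal sketch assuming the default `REPLICA_GROUP_ID=0` and `RUN=0` from the hunks above:

```python
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Load the events written to output/replica-0/tensorboard by this script.
acc = EventAccumulator("output/replica-0/tensorboard")
acc.Reload()

# Expect tags like 'Run:0/loss', 'Run:0/num_participants', 'Run:0/current_step'.
print(acc.Tags()["scalars"])
for event in acc.Scalars("Run:0/loss"):
    print(event.step, event.value)
```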