Commit 3d12f28
Move Checkpointer to own file and restore single-device support
The Checkpointer class was specific to FSDP DTensor-style sharded model checkpointing. I instead converted it to an abstract-ish base class and created two subclasses for single-device and multi-device checkpointing, respectively.

Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent 7ff6186 commit 3d12f28

File tree

4 files changed (+161 −99 lines):

scripts/train.py (+1 −1)
src/speculators/train/checkpointer.py (+149, new file)
src/speculators/train/distributed_batch_sampler.py (−1)
src/speculators/train/trainer.py (+11 −97)

scripts/train.py

Lines changed: 1 addition & 1 deletion

@@ -42,7 +42,7 @@
     .to(DEVICE)
 )

-setup_metric_logger(loggers=[], run_name=None)
+setup_metric_logger(loggers=[], run_name=None, output_dir="./logs")
 setup_root_logger()
 # END TEMP MODEL SETUP

src/speculators/train/checkpointer.py (new file)

Lines changed: 149 additions & 0 deletions

from abc import abstractmethod
import os
from pathlib import Path
import torch
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
    StateDictOptions,
)

import torch.distributed as dist


class BaseCheckpointer:
    """Helper class to save and load checkpoints.

    Checkpoint file structure:
    ../path/
        0/  # epoch number
            model_state_dict.pt
            optimizer_state_dict.pt
        1/
            model_state_dict.pt
            optimizer_state_dict.pt
        ...
    """

    def __init__(self, path: Path | str, try_load_last_checkpoint: bool = True):
        self.path = Path(path)
        if try_load_last_checkpoint:
            self.previous_epoch: int = self._get_previous_epoch()
        else:
            self.previous_epoch: int = -1

    @abstractmethod
    def load_model_state_dict(self, model: torch.nn.Module):
        raise NotImplementedError

    @abstractmethod
    def load_optimizer_state_dict(
        self, model: torch.nn.Module, optimizer: torch.optim.Optimizer
    ):
        raise NotImplementedError

    @abstractmethod
    def save_checkpoint(
        self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, epoch: int
    ):
        raise NotImplementedError

    def _get_previous_epoch(self) -> int:
        if not self.path.exists():
            return -1
        last_checkpoint_num = -1
        for d in self.path.iterdir():
            if d.is_dir():
                try:
                    last_checkpoint_num = max(last_checkpoint_num, int(d.name))
                except ValueError:
                    continue
        return last_checkpoint_num

    def model_path(self, epoch: int):
        model_fname = "model_state_dict.pt"
        return self.path / str(epoch) / model_fname

    def optimizer_path(self, epoch: int):
        optimizer_fname = "optimizer_state_dict.pt"
        return self.path / str(epoch) / optimizer_fname


class SingleGPUCheckpointer(BaseCheckpointer):
    def load_model_state_dict(self, model: torch.nn.Module):
        full_state_dict = torch.load(
            self.model_path(self.previous_epoch),
            weights_only=True,
            map_location="cuda:0",  # todo: make this configurable
        )
        model.load_state_dict(full_state_dict)

    def load_optimizer_state_dict(
        self, model: torch.nn.Module, optimizer: torch.optim.Optimizer
    ):
        full_state_dict = torch.load(
            self.optimizer_path(self.previous_epoch),
            weights_only=True,
            map_location="cuda:0",  # todo: make this configurable
        )
        optimizer.load_state_dict(full_state_dict)

    def save_checkpoint(
        self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, epoch: int
    ):
        os.makedirs(self.path / str(epoch), exist_ok=True)
        torch.save(model.state_dict(), self.model_path(epoch))
        torch.save(optimizer.state_dict(), self.optimizer_path(epoch))


class DistributedCheckpointer(BaseCheckpointer):
    def load_model_state_dict(self, model: torch.nn.Module):
        full_state_dict = torch.load(
            self.model_path(self.previous_epoch),
            mmap=True,
            weights_only=True,
            map_location="cpu",
        )
        set_model_state_dict(
            model,
            full_state_dict,
            options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
        )
        dist.barrier()

    def load_optimizer_state_dict(self, model, optimizer: torch.optim.Optimizer):
        full_state_dict = torch.load(
            self.optimizer_path(self.previous_epoch),
            mmap=True,
            weights_only=True,
            map_location="cpu",
        )
        set_optimizer_state_dict(
            model,
            optimizer,
            full_state_dict,
            options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
        )
        dist.barrier()

    def save_checkpoint(
        self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, epoch: int
    ):
        model_state_dict = get_model_state_dict(
            model, options=StateDictOptions(full_state_dict=True, cpu_offload=True)
        )
        optimizer_state_dict = get_optimizer_state_dict(
            model,
            optimizer,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )

        if dist.get_rank() == 0:
            # Only rank 0 saves the checkpoint
            os.makedirs(self.path / str(epoch), exist_ok=True)
            torch.save(model_state_dict, self.model_path(epoch))
            torch.save(optimizer_state_dict, self.optimizer_path(epoch))

        dist.barrier()
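For reference, a minimal sketch of how these checkpointer classes are meant to be driven. The model, optimizer, path, and epoch count below are illustrative placeholders, not code from this commit; the actual trainer wiring appears in the trainer.py diff further down.

import torch
import torch.distributed as dist

from speculators.train.checkpointer import (
    DistributedCheckpointer,
    SingleGPUCheckpointer,
)

# Hypothetical stand-ins purely for illustration.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters())

# Use the multi-device checkpointer only when a process group is active.
is_distributed = dist.is_available() and dist.is_initialized()
checkpointer_cls = DistributedCheckpointer if is_distributed else SingleGPUCheckpointer
checkpointer = checkpointer_cls("checkpoints/run0", try_load_last_checkpoint=True)

# previous_epoch is -1 when no prior checkpoint directory exists.
if checkpointer.previous_epoch >= 0:
    checkpointer.load_model_state_dict(model)
    checkpointer.load_optimizer_state_dict(model, optimizer)

for epoch in range(checkpointer.previous_epoch + 1, 3):
    # ... one training epoch elided ...
    checkpointer.save_checkpoint(model, optimizer, epoch)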

src/speculators/train/distributed_batch_sampler.py

Lines changed: 0 additions & 1 deletion

@@ -111,7 +111,6 @@ def _assign_to_packed_batches(
         # Break and drop whatever lengths we have left
         break

-
     # binary search in [1, 1 + upper bound for x)
     left = 1
     right = 1 + np.searchsorted(

src/speculators/train/trainer.py

Lines changed: 11 additions & 97 deletions

@@ -1,21 +1,17 @@
-import os
-from pathlib import Path
 import torch
 from torch.distributed.fsdp import FSDPModule, fully_shard, MixedPrecisionPolicy
 from torch.utils.data import DataLoader
 from tqdm.rich import tqdm  # todo: requries tqdm and rich
-from torch.distributed.checkpoint.state_dict import (
-    get_model_state_dict,
-    get_optimizer_state_dict,
-    set_model_state_dict,
-    set_optimizer_state_dict,
-    StateDictOptions,
-)


 import torch.distributed as dist
 import logging

+from speculators.train.checkpointer import (
+    SingleGPUCheckpointer,
+    DistributedCheckpointer,
+)
+
 root_logger = logging.getLogger("speculators")
 metric_logger = logging.getLogger("speculators.metrics")

@@ -43,93 +39,6 @@ def compute_draft_accuracy(
     return torch.tensor(accuracies, device=target_logits.device)


-class Checkpointer:
-    """Helper class to save and load checkpoints.
-
-    Checkpoint file structure:
-    ../path/
-        0/  # epoch number
-            model_state_dict.pt
-            optimizer_state_dict.pt
-        1/
-            model_state_dict.pt
-            optimizer_state_dict.pt
-        ...
-    """
-
-    model_fname = "model_state_dict.pt"
-    optimizer_fname = "optimizer_state_dict.pt"
-
-    def __init__(self, path: Path | str, try_load_last_checkpoint: bool = True):
-        self.path = Path(path)
-        if try_load_last_checkpoint:
-            self.previous_epoch: int = self._get_previous_epoch()
-        else:
-            self.previous_epoch: int = -1
-
-    def _get_previous_epoch(self) -> int:
-        if not self.path.exists():
-            return -1
-        last_checkpoint_num = -1
-        for d in self.path.iterdir():
-            if d.is_dir():
-                try:
-                    last_checkpoint_num = max(last_checkpoint_num, int(d.name))
-                except ValueError:
-                    continue
-        return last_checkpoint_num
-
-    def load_model_state_dict(self, model: torch.nn.Module):
-        full_state_dict = torch.load(
-            self.path / str(self.previous_epoch) / self.model_fname,
-            mmap=True,
-            weights_only=True,
-            map_location="cpu",
-        )
-        set_model_state_dict(
-            model,
-            full_state_dict,
-            options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
-        )
-        dist.barrier()
-
-    def load_optimizer_state_dict(self, model, optimizer: torch.optim.Optimizer):
-        full_state_dict = torch.load(
-            self.path / str(self.previous_epoch) / self.optimizer_fname,
-            mmap=True,
-            weights_only=True,
-            map_location="cpu",
-        )
-        set_optimizer_state_dict(
-            model,
-            optimizer,
-            full_state_dict,
-            options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
-        )
-        dist.barrier()
-
-    def save_checkpoint(
-        self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, epoch: int
-    ):
-        model_state_dict = get_model_state_dict(
-            model, options=StateDictOptions(full_state_dict=True, cpu_offload=True)
-        )
-        optimizer_state_dict = get_optimizer_state_dict(
-            model,
-            optimizer,
-            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
-        )
-
-        if dist.get_rank() == 0:
-            # Only rank 0 saves the checkpoint
-            checkpoint_dir = self.path / str(epoch)
-            os.makedirs(checkpoint_dir, exist_ok=True)
-            torch.save(model_state_dict, checkpoint_dir / self.model_fname)
-            torch.save(optimizer_state_dict, checkpoint_dir / self.optimizer_fname)
-
-        dist.barrier()
-
-
 def apply_fully_sharded(model: torch.nn.Module):
     fsdp_kwargs = {
         "mp_policy": MixedPrecisionPolicy(

@@ -168,7 +77,10 @@ def __init__(
         self.is_distributed = is_distributed
         self.local_rank = local_rank
         self.world_size = world_size
-        self.checkpointer = Checkpointer(
+        checkpointer_class = (
+            DistributedCheckpointer if is_distributed else SingleGPUCheckpointer
+        )
+        self.checkpointer = checkpointer_class(
             config["save_path"],
             try_load_last_checkpoint=config.get("resume_from_checkpoint", False),
         )

@@ -196,6 +108,8 @@ def setup_model(self):
                 if hasattr(sub_module, "reset_parameters"):
                     sub_module.reset_parameters()
             # todo: We need to make sure we're loading lm_head and embed_tokens after this reset
+        else:
+            self.model.to(self.local_rank)
         self.verifier_lm_head = self.verifier_lm_head.to(self.local_rank)

     def setup_optimizer(self):
