
Commit 345b314

Add support for new and old data formats
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent 583d5c9 commit 345b314

File tree

scripts/train_llama3_8b_drafter.py
src/speculators/train/data.py

2 files changed: +92 -15 lines changed


scripts/train_llama3_8b_drafter.py

Lines changed: 26 additions & 3 deletions
@@ -13,6 +13,8 @@
     Eagle3SampleFileDataset,
     create_collate_fn,
     split_files,
+    standardize_data_v0,
+    standardize_data_v1,
 )
 from speculators.train.distributed_batch_sampler import (
     MultipackDistributedBatchSamplerV2,
@@ -41,6 +43,7 @@ def setup_dataloader(
     world_size: int,
     local_rank: int,
     add_noise: bool = True,
+    data_format_version: int = 1,
 ):
     if add_noise:
         noise_transform = AddUniformNoise(
@@ -49,8 +52,15 @@
     else:
         noise_transform = None

+    standardize_fn = (
+        standardize_data_v1 if data_format_version == 1 else standardize_data_v0
+    )
+
     dataset = Eagle3SampleFileDataset(
-        file_list=file_list, max_len=TOTAL_SEQ_LEN, transform=noise_transform
+        file_list=file_list,
+        max_len=TOTAL_SEQ_LEN,
+        transform=noise_transform,
+        standardize_fn=standardize_fn,
     )
     batch_sampler = MultipackDistributedBatchSamplerV2(
         batch_max_length=TOTAL_SEQ_LEN,
@@ -118,8 +128,20 @@ def main(args: argparse.Namespace):

     # Setup dataloaders
     train_files, val_files = split_files(args.data_path, ratio=0.9)
-    train_loader = setup_dataloader(train_files, world_size, local_rank, add_noise=True)
-    val_loader = setup_dataloader(val_files, world_size, local_rank, add_noise=False)
+    train_loader = setup_dataloader(
+        train_files,
+        world_size,
+        local_rank,
+        add_noise=True,
+        data_format_version=args.data_format_version,
+    )
+    val_loader = setup_dataloader(
+        val_files,
+        world_size,
+        local_rank,
+        add_noise=False,
+        data_format_version=args.data_format_version,
+    )

     # Setup trainer
     trainer_config = TrainerConfig(
@@ -154,6 +176,7 @@ def parse_args():
         default="",
         help="One of 'trackio', 'wandb', 'tensorboard' or comma separated list of them",
     )
+    parser.add_argument("--data-format-version", type=int, default=1)
     parser.add_argument("--log-dir", type=str, default="./logs")
     parser.add_argument("--run-name", type=str, default=None)
     parser.add_argument("--num-layers", type=int, default=1)

src/speculators/train/data.py

Lines changed: 66 additions & 12 deletions
@@ -2,6 +2,7 @@
 import math
 import os
 import random
+from collections.abc import Callable
 from pathlib import Path
 from typing import Any

@@ -114,14 +115,57 @@ def split_files(datapath: str, ratio: float = 0.9, seed: int = 0):
     return train_files, val_files


+# Data standardization functions
+StandardizeFnSig = Callable[[dict[str, Any]], dict[str, Any]]
+
+
+def standardize_data_v0(data: dict[str, Any]) -> dict[str, Any]:
+    # v0 data format:
+    # {
+    #     "input_ids": [seq_len],
+    #     "loss_mask": [seq_len],
+    #     "hidden_state": [seq_len, 3 * hidden_size],
+    #     "target": [seq_len, hidden_size],
+    # }
+
+    return {
+        "hidden_states": data["hidden_state"],
+        "input_ids": data["input_ids"],
+        "verifier_last_hidden_states": data["target"],
+        "loss_mask": data["loss_mask"],
+    }
+
+
+def standardize_data_v1(data: dict[str, Any]) -> dict[str, Any]:
+    # v1 data format:
+    # {
+    #     "input_ids": [seq_len],
+    #     "loss_mask": [seq_len],
+    #     "hidden_states": [
+    #         [seq_len, hidden_size],
+    #         [seq_len, hidden_size],
+    #         [seq_len, hidden_size],
+    #         ...
+    #     ],
+    # }
+
+    return {
+        "hidden_states": torch.cat(data["hidden_states"][:-1], dim=-1),
+        "input_ids": data["input_ids"],
+        "verifier_last_hidden_states": data["hidden_states"][-1],
+        "loss_mask": data["loss_mask"],
+    }
+
+
 class Eagle3SampleFileDataset(Dataset):
     def __init__(
         self,
         max_len: int,
         datapath: str | None = None,
         file_list: list[str] | None = None,
-        transform=None,
+        transform: TransformTensors | None = None,
         hidden_states_dtype=torch.float,
+        standardize_fn: StandardizeFnSig = standardize_data_v1,
     ):
         if datapath is not None and file_list is not None:
             raise ValueError("Only one of datapath or file_list may be provided")
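
To make the two layouts concrete, here is a small hedged check of both standardization functions (seq_len, hidden_size, and the four-entry v1 list are illustrative values, not from the commit):

    import torch

    from speculators.train.data import standardize_data_v0, standardize_data_v1

    seq_len, hidden_size = 8, 16  # illustrative sizes

    # v0: one fused drafter input plus a separate verifier target.
    v0_sample = {
        "input_ids": torch.arange(seq_len),
        "loss_mask": torch.ones(seq_len),
        "hidden_state": torch.randn(seq_len, 3 * hidden_size),
        "target": torch.randn(seq_len, hidden_size),
    }

    # v1: a list of per-layer hidden states; all but the last are concatenated
    # into the drafter input, and the last becomes the verifier target.
    v1_sample = {
        "input_ids": torch.arange(seq_len),
        "loss_mask": torch.ones(seq_len),
        "hidden_states": [torch.randn(seq_len, hidden_size) for _ in range(4)],
    }

    for fn, sample in [(standardize_data_v0, v0_sample), (standardize_data_v1, v1_sample)]:
        out = fn(sample)
        assert out["hidden_states"].shape == (seq_len, 3 * hidden_size)
        assert out["verifier_last_hidden_states"].shape == (seq_len, hidden_size)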
@@ -134,6 +178,7 @@ def __init__(
         self.data: list[str] = file_list
         self.max_len = max_len
         self.transform = transform
+        self.standardize_fn = standardize_fn
         self.hidden_states_dtype = hidden_states_dtype
         self.approx_lengths = self._compute_approx_lengths()

@@ -155,24 +200,24 @@ def _compute_approx_lengths(self) -> list[int]:
     def __getitem__(self, index) -> BatchType:
         data = torch.load(self.data[index])

-        # todo: standardize names during data generation and then remove this
-        data["hidden_states"] = data["hidden_state"]
-        data["verifier_last_hidden_states"] = data["target"]
-        del data["hidden_state"]
-        del data["target"]
+        data = self.standardize_fn(data)
+        # data structure: {
+        #     "hidden_states": [seq_len, 3 * hidden_size],
+        #     "input_ids": [seq_len],
+        #     "verifier_last_hidden_states": [seq_len, hidden_size],
+        #     "loss_mask": [seq_len],
+        # }

-        # todo: standardize dtypes during data generation and then remove this
+        # Convert hidden states to the correct dtype
         data = {
             k: v.to(self.hidden_states_dtype) if "hidden_states" in k else v
             for k, v in data.items()
         }

-        seq_len = data["input_ids"].shape[0]
         # Add lengths tensor
+        seq_len = data["input_ids"].shape[0]
         data["lengths"] = torch.tensor([seq_len], dtype=torch.long)
-
-        if self.transform:
-            data = self.transform(data)
+        # shape: [1]

         data["position_ids"] = torch.arange(seq_len, dtype=torch.long)
         # shape: [seq_len]
@@ -186,6 +231,10 @@ def __getitem__(self, index) -> BatchType:
         #     "position_ids": [seq_len],
         # }

+        # Apply transform
+        if self.transform:
+            data = self.transform(data)
+
         # Note: shift_batch will reduce seq_len by 1
         return shift_batch(data)
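
As a usage note, a hedged sketch of pulling one sample through the updated dataset (the file path and max_len are placeholders; the training script itself passes TOTAL_SEQ_LEN and a noise transform):

    from speculators.train.data import Eagle3SampleFileDataset, standardize_data_v0

    dataset = Eagle3SampleFileDataset(
        max_len=2048,                        # placeholder; the script uses TOTAL_SEQ_LEN
        file_list=["./data/sample_0.pt"],    # hypothetical sample file
        standardize_fn=standardize_data_v0,  # read a legacy v0 file
    )
    batch = dataset[0]  # standardized, dtype-cast, lengths/position_ids added, then shifted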

@@ -194,15 +243,20 @@ def create_collate_fn(max_len: int):
     def collate_fn(batch: list[BatchType]) -> BatchType:
         collated_data = {}
         for key in batch[0]:
+            # Concatenate the tensors along the seq (0th) dimension
             collated_data[key] = torch.cat([b[key] for b in batch], dim=0)
+            # shape: [total_seq_len, ...]

             if key != "lengths":
+                # Slice and pad on seq (0th) dimension to max_len
                 collated_data[key] = slice_and_pad_to_length(
                     collated_data[key], max_len
                 ).unsqueeze(0)
                 # shape: [1, max_len, ...]

-        # Handle lengths update
+        # Include lengths while they fit in max_len
+        # The last included length is (if necessary) truncated
+        # Any additional lengths are discarded
         lengths = collated_data["lengths"]
         new_lengths = []
         cum_length = 0
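
The diff view cuts off here, before the loop that implements the comment above. As a hedged illustration only (not the committed code), the described truncation behaves like this:

    # Keep lengths while they fit in max_len; truncate the last one that
    # crosses the boundary; drop the rest.
    lengths, max_len = [5, 7, 4], 10  # illustrative values
    new_lengths, cum_length = [], 0
    for length in lengths:
        if cum_length + length <= max_len:
            new_lengths.append(length)
            cum_length += length
        else:
            remaining = max_len - cum_length
            if remaining > 0:
                new_lengths.append(remaining)  # truncated final length
            break                              # discard the rest
    print(new_lengths)  # [5, 5]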
