Commit d7a885c

Porting workload input pipeline to torch
- Added a `limit_tf_threads` parameter to `pytorch_init` to control TensorFlow threading based on the workload type; the dataloader was going OOM otherwise.
- Updated the input pipeline to support `None` for weights (to save memory).
- Modified the Transformer model's `forward` method to optionally return the loss during training; it should be better to fuse the loss later.
- Adjusted the torch LM workload configuration (model dimensions and parameters) to match the JAX workload.
- Updated the transformers version in `pyproject.toml`; the older pinned version seems to be unavailable.
1 parent 65369f2 commit d7a885c

6 files changed (+90, -121 lines)


algoperf/pytorch_utils.py

Lines changed: 2 additions & 4 deletions
@@ -27,7 +27,7 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]:
   return use_pytorch_ddp, rank, device, n_gpus


-def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None:
+def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler, limit_tf_threads=True) -> None:
   # Make sure no GPU memory is preallocated to Jax.
   os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
   # Only use CPU for Jax to avoid memory issues.
@@ -39,18 +39,16 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None:

   if use_pytorch_ddp:
     # Avoid tf input pipeline creating too many threads.
-    if rank != 0:
+    if rank != 0 and limit_tf_threads:
       tf.config.threading.set_intra_op_parallelism_threads(1)
       tf.config.threading.set_inter_op_parallelism_threads(1)

     torch.cuda.set_device(rank)
     profiler.set_local_rank(rank)
     # Only log once (for local rank == 0).
     if rank != 0:
-
       def logging_pass(*args):
         pass
-
       logging.info = logging_pass
     # Initialize the process group.
     dist.init_process_group('nccl')
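
A minimal usage sketch (not part of the commit) of the new flag. It assumes the PassThroughProfiler from algoperf.profiler used elsewhere in the repo; the LM workload is expected to pass limit_tf_threads=False so its tf.data pipeline keeps its threads (see the submission_runner.py change below).

  # Hypothetical illustration only: mirrors how submission_runner.py is
  # expected to call pytorch_init for the LM workload after this change.
  from algoperf import pytorch_utils
  from algoperf.profiler import PassThroughProfiler  # assumed import path, as in the repo

  use_ddp, rank, device, n_gpus = pytorch_utils.pytorch_setup()
  # For the LM workload, keep TF threading unrestricted to avoid starving tf.data.
  pytorch_utils.pytorch_init(use_ddp, rank, PassThroughProfiler(), limit_tf_threads=False)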

algoperf/workloads/lm/input_pipeline.py

Lines changed: 7 additions & 1 deletion
@@ -107,7 +107,13 @@ def get_lm_dataset(
     repeated_sequences_dataset = shuffled_sequences_ds.repeat()
     ds = repeated_sequences_dataset.batch(
       global_batch_size, drop_remainder=False
-    ).prefetch(tf.data.experimental.AUTOTUNE)
+    )
+    ds = ds.map(lambda x: {
+      'inputs': x['inputs'],
+      'targets': x['targets'],
+      'weights': None,
+    })
+    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
   elif split == 'eval_train':
     ds = batch_with_padding(
       sequences_ds,
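
Since train batches now carry 'weights': None, a downstream consumer that needs an explicit mask can fall back to all-ones. A minimal sketch with a hypothetical helper (not part of the commit):

  import torch

  def effective_weights(batch):
    # Treat a missing/None weights entry as "every token counts".
    if batch.get('weights') is None:
      return torch.ones_like(batch['targets'], dtype=torch.float32)
    return batch['weights']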

algoperf/workloads/lm/lm_pytorch/plainlm_model.py

Lines changed: 42 additions & 32 deletions
@@ -159,7 +159,7 @@ def __init__(self, cfg):
     if cfg.tie_embeddings:
       self.tie_weights()

-  def forward(self, x):
+  def forward(self, x, targets=None):
     # x: (bsz, seqlen)
     x = self.embed_tokens(x)  # (bsz, seqlen, dim)
     L = x.shape[1]
@@ -178,7 +178,12 @@ def forward(self, x):

     for layer in self.layers:
       x = layer(x, freqs_cis)  # (bsz, seqlen, dim)
-    return self.lm_head(self.out_norm(x))  # (bsz, seqlen, vocab_size)
+    out = self.lm_head(self.out_norm(x))  # (bsz, seqlen, vocab_size)
+    if targets is not None:
+      loss = F.cross_entropy(
+        out.view(-1, out.size(-1)), targets.view(-1), ignore_index=-100)
+      return out, loss
+    return out

   def predict(self, x, k=1):
     """Generate k tokens autoregressively.
@@ -190,18 +195,14 @@ def predict(self, x, k=1):
     Returns:
       Tuple of (input_ids, predicted_ids)
     """
-    # For debugging
-    predictions = []
-
-    batch_size = x.shape[0]
-    seq_len = x.shape[1]

     # Store original input
     original_input = x.clone()
     generated_input = x.clone()

     # Generate k tokens autoregressively
     for i in range(k):
+
       # Get logits for the entire sequence
       logits = self(generated_input)
@@ -212,24 +213,20 @@ def predict(self, x, k=1):
       # This is a common issue - the model gets stuck repeating the last token
       last_token_id = generated_input[:, -1]
       next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf'))
-
-      # Print top 5 tokens for debugging
-      if i == 0:
-        print("\nPyTorch detailed prediction:")
-        top5_values, top5_indices = torch.topk(next_token_logits[0], 5)
-        for j, (idx, val) in enumerate(zip(top5_indices.tolist(), top5_values.tolist())):
-          prob = torch.softmax(next_token_logits[0], dim=-1)[idx].item()
-          print(f"  Top {j+1}: Token {idx}, logit={val:.2f}, prob={prob:.6f}")
-
+
       # Get the most likely token
       next_token = torch.argmax(next_token_logits, dim=-1)
-      predictions.append(next_token.item())

       # Append the predicted token to the sequence
       next_token = next_token.unsqueeze(1)  # Add sequence dimension
       generated_input = torch.cat([generated_input, next_token], dim=1)

-    print(f"  Full predictions step by step: {predictions}")
+    # For debugging, print predictions for the first item in the batch
+    print("\nPyTorch detailed prediction (first item in batch):")
+    predicted_sequence = generated_input[0, -k:].tolist()
+    print(f"  Predicted token IDs: {predicted_sequence}")
+    for i, token_id in enumerate(predicted_sequence):
+      print(f"  Step {i+1}: Predicted token {token_id}")

     # Return all tokens, not just the last k
     return original_input, generated_input[:, -k:]
@@ -269,30 +266,43 @@ def count_params(self, non_embedding=True):
 def main():
   print("Initializing transformer model and running forward pass...")

-  seq_length = 512
+  seq_length = 1024

   # Define model configuration
   config = ModelConfig(
-    vocab_size=32000,    # Common vocab size for tokenizers like BPE or SentencePiece
+    vocab_size=50257,    # Common vocab size for tokenizers like BPE or SentencePiece
     seq_len=seq_length,  # Maximum sequence length
-    dim=768,             # Embedding dimension
+    dim=1024,            # Embedding dimension
     expand=4.0,          # MLP expansion factor
     n_layers=12,         # Number of transformer layers
-    n_heads=12,          # Number of attention heads
+    n_heads=8,           # Number of attention heads
     rmsnorm_eps=1e-6,    # RMSNorm epsilon
     tie_embeddings=True  # Tie embedding and output weights
   )

-  def tie_weights(self):
-    self.lm_head.weight = self.embed_tokens.weight
+  # Instantiate the model
+  model = Transformer(config)
+  print(f"Model has {model.count_params():,} parameters.")

-  def count_params(self, non_embedding=True):
-    n_params = sum(p.numel() for p in self.parameters())
-    if non_embedding:
-      n_params -= self.embed_tokens.weight.numel()
-    if (not self.lm_head.weight
-        is self.embed_tokens.weight):  # if no weight tying
-      n_params -= self.lm_head.weight.numel()
-    return n_params
+  # Create some random input data
+  batch_size = 2
+  input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length))
+
+  # Move data to the same device as the model
+  if torch.cuda.is_available():
+    input_ids = input_ids.cuda()
+
+  # Run a forward pass
+  print(f"Running forward pass with input shape: {input_ids.shape}")
+  logits = model(input_ids)
+  print(f"Output logits shape: {logits.shape}")

+  # Run prediction
+  print("Running prediction...")
+  original_input, predicted_ids = model.predict(input_ids[:, :10], k=5)
+  print(f"Original input shape for prediction: {original_input.shape}")
+  print(f"Predicted IDs shape: {predicted_ids.shape}")
+  print(f"Predicted IDs: {predicted_ids}")

+if __name__ == "__main__":
+  main()
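
A short training-step sketch (illustrative values, not part of the commit) showing the new fused path, where forward(x, targets=...) returns both the logits and the cross-entropy loss in one call; ModelConfig and Transformer are the classes defined in plainlm_model.py above, and the optimizer/learning rate are arbitrary examples.

  import torch

  cfg = ModelConfig(vocab_size=50257, seq_len=1024, dim=1024, expand=4.0,
                    n_layers=12, n_heads=8, rmsnorm_eps=1e-6, tie_embeddings=True)
  model = Transformer(cfg)
  optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)  # example LR only

  inputs = torch.randint(0, 50257, (2, 1024))   # random token ids
  targets = torch.randint(0, 50257, (2, 1024))

  logits, loss = model(inputs, targets=targets)  # loss computed inside forward
  loss.backward()
  optimizer.step()
  optimizer.zero_grad()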

algoperf/workloads/lm/lm_pytorch/workload.py

Lines changed: 36 additions & 82 deletions
@@ -14,6 +14,7 @@
   Transformer,
 )
 from algoperf.workloads.lm.workload import BaseLmWorkload
+from algoperf.workloads.lm.input_pipeline import get_data_iter

 USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()

@@ -37,10 +38,11 @@ def init_model_fn(
     cfg = ModelConfig(
       vocab_size=self._vocab_size,
       seq_len=self._seq_len,
-      dim=512,      # Model dimension
-      expand=4,     # MLP expansion factor
-      n_layers=6,   # Number of transformer layers
-      n_heads=8,    # Number of attention heads
+      dim=self._emb_dim,                      # Model dimension
+      expand=self._mlp_dim // self._emb_dim,  # MLP expansion factor
+      # FIXME(rka97): fix expansion factor
+      n_layers=self._n_layers,                # Number of transformer layers
+      n_heads=self._n_heads,                  # Number of attention heads
       rmsnorm_eps=1e-6,
       tie_embeddings=True
     )
@@ -65,7 +67,7 @@ def model_fn(
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
       update_batch_norm: bool,
-      dropout_rate: None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+      dropout_rate: float = 0.0) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:

     del model_state, rng, update_batch_norm, dropout_rate
     model = params
@@ -87,10 +89,8 @@ def _build_input_queue(
       num_batches: Optional[int] = None,
       repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]:
     """Build an input queue for the given split."""
-    from algoperf.workloads.lm.input_pipeline import get_lm_dataset
     local_batch_size = global_batch_size // N_GPUS
-
-    loader = get_lm_dataset(
+    loader = get_data_iter(
       data_rng=data_rng,
       split=split,
       data_dir=data_dir,
@@ -99,100 +99,54 @@
     )
     if USE_PYTORCH_DDP:
       loader = islice(loader, RANK, None, N_GPUS)
-    seq_len = self._seq_len
-    weights = None
-
     dtype = torch.int32
-    is_train = split == 'train'
-
     for batch in loader:
-      inputs = batch['inputs']
-      targets = batch['targets']
-
-      if USE_PYTORCH_DDP:
-        if not is_train:
-          # During eval, the batch size of the remainder might be different
-          per_device_batch_size = torch.tensor(
-              targets.shape[0], dtype=dtype, device=DEVICE)
-          dist.broadcast(per_device_batch_size, src=0)
-          local_batch_size = per_device_batch_size.item()
-        # Broadcast to all devices
-        #dist.broadcast(inputs, src=0)
-        #dist.broadcast(targets, src=0)
-
-      if weights is None:
-        weights = torch.ones((local_batch_size, seq_len), device=DEVICE)
       batch = {
-        'inputs': torch.tensor(inputs, device=DEVICE, dtype=dtype),
-        'targets': torch.tensor(targets, device=DEVICE, dtype=dtype),
-        'weights': weights,
+        'inputs': torch.tensor(batch['inputs'], device=DEVICE, dtype=dtype),
+        'targets': torch.tensor(batch['targets'], device=DEVICE, dtype=torch.int64),
+        'weights': None,
       }
       yield batch

   def is_output_params(self, param_name: str) -> bool:
     """Return whether the given parameter is an output parameter."""
     return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name

-  def _eval_batch(self,
-                  params: spec.ParameterContainer,
-                  batch: Dict[str, spec.Tensor],
-                  model_state: spec.ModelAuxiliaryState,
-                  rng: spec.RandomState) -> spec.Tensor:
-    """Evaluate the model on a single batch."""
-    model = params
-    logits, _ = self.model_fn(
-        model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False)
-
-    # Handle both one-hot and token ID targets
-    targets = batch['targets']
-    if targets.dim() == 3:  # one-hot
-      loss = -torch.sum(targets * torch.nn.functional.log_softmax(logits, dim=-1))
-    else:  # token IDs
-      # TODO(kasimbeg): before deleting make sure we have defined self.weighted_cross_entropy so that we can call the shared workload _eval_batch.
-      loss = torch.nn.functional.cross_entropy(
-          logits.view(-1, logits.size(-1)),
-          targets.view(-1),
-          reduction='sum'
-      )
-    return loss
-
-  def loss_fn(
-      self,
-      label_batch: spec.Tensor,
-      logits_batch: spec.Tensor,
-      mask_batch: Optional[spec.Tensor] = None,
-      label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]:
+  # FIXME(rka97): Implement label smoothing
+  def compute_weighted_cross_entropy(self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor, label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]:
     """Compute cross-entropy loss for language modeling in PyTorch."""
-    vocab_size = logits_batch.shape[-1]
+    vocab_size = logits.size(-1)

-    if len(label_batch.shape) == len(logits_batch.shape):
+    if len(labels.shape) == len(logits.shape):
       # One-hot labels
-      log_probs = torch.nn.functional.log_softmax(logits_batch, dim=-1)
-      loss = -torch.sum(label_batch * log_probs, dim=-1)
+      log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
+      loss = -torch.sum(labels * log_probs, dim=-1)
     else:
       # Dense labels
       loss = torch.nn.functional.cross_entropy(
-          logits_batch,
-          label_batch,
+          logits.view(-1, vocab_size),
+          labels.view(-1),
           reduction='none')
-      if mask_batch is not None:
-        loss = loss * mask_batch
+      loss = loss.view_as(labels)
+
+    if weights is not None:
+      loss = loss * weights

-    n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0]
+    n_valid = weights.sum() if weights is not None else torch.tensor(labels.numel(), dtype=torch.float32, device=labels.device)
     return {
       'summed': loss.sum(),
       'n_valid_examples': n_valid,
-      'per_example': loss
+      'per_example': loss,
     }

-  def _normalize_eval_metrics(
-      self, num_examples: int, total_metrics: Dict[str, Any]
-  ) -> Dict[str, float]:
-    """Normalize eval metrics."""
-    del num_examples
-    if USE_PYTORCH_DDP:
-      for metric in total_metrics.values():
-        dist.all_reduce(metric)
-    total_metrics = {k: v.item() for k, v in total_metrics.items()}
-    eval_denominator = total_metrics.pop('denominator')
-    return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics)
+  def _normalize_eval_metrics(
+      self, num_examples: int, total_metrics: Dict[str, Any]
+  ) -> Dict[str, float]:
+    """Normalize eval metrics."""
+    del num_examples
+    if USE_PYTORCH_DDP:
+      for metric in total_metrics.values():
+        dist.all_reduce(metric)
+    total_metrics = {k: v.item() for k, v in total_metrics.items()}
+    eval_denominator = total_metrics.pop('denominator')
+    return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics)
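
A quick sketch (illustrative shapes, not part of the commit) of how the new compute_weighted_cross_entropy would be called when the input queue yields weights=None, in which case every target token counts toward the denominator; `workload` is assumed to be an instance of the PyTorch LM workload class above.

  import torch

  logits = torch.randn(8, 256, 50257)          # (batch, seq_len, vocab_size)
  labels = torch.randint(0, 50257, (8, 256))   # dense token ids
  metrics = workload.compute_weighted_cross_entropy(logits, labels, weights=None)
  mean_loss = metrics['summed'] / metrics['n_valid_examples']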

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ librispeech_conformer = [
   "pydub==0.25.1",
 ]
 wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.19.0"]
-lm = ["transformers==4.25.4", "datasets==3.6.0"]
+lm = ["transformers==4.26", "datasets==3.6.0"]

 # Frameworks
 jax_core_deps = [

submission_runner.py

Lines changed: 2 additions & 1 deletion
@@ -784,7 +784,8 @@ def main(_):
     os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'

   if FLAGS.framework == 'pytorch':
-    pytorch_init(USE_PYTORCH_DDP, RANK, profiler)
+    limit_tf_threads = (base_workload != 'lm')
+    pytorch_init(USE_PYTORCH_DDP, RANK, profiler, limit_tf_threads=limit_tf_threads)

   # TODO: remove once issue resolved.
   if FLAGS.pytorch_eval_num_workers != 0:
