
Commit f111aea

Merge branch 'lm_workload' of github.com:mlcommons/algorithmic-efficiency into lm_workload
2 parents: d7a885c + 6171b2d · commit f111aea

File tree

4 files changed (+14 −14 lines)


algoperf/workloads/lm/input_pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -126,7 +126,7 @@ def get_lm_dataset(
     ds = ds.map(lambda x: {'inputs': x['inputs'],
                            'targets': x['targets'],
                            'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)})
-    ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE)  # todo(kasimbeg): set final size of validation
+    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
   elif split == 'validation':
     ds = batch_with_padding(
         sequences_ds,
@@ -139,6 +139,6 @@ def get_lm_dataset(
     ds = ds.map(lambda x: {'inputs': x['inputs'],
                            'targets': x['targets'],
                            'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)})
-    ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE)  # todo(kasimbeg): set final size
+    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

   return ds
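
For context, the map that surrounds the changed line zeroes out loss weights at padded positions, and the new line simply drops the temporary take(1000) cap before prefetching. Below is a minimal, self-contained sketch of that step; PAD_ID = 0 and the toy batch are assumptions for illustration, not the workload's actual values.

import tensorflow as tf

PAD_ID = 0  # assumed for illustration; the workload defines its own pad id

# Toy stand-in for the padded sequences produced by batch_with_padding.
ds = tf.data.Dataset.from_tensor_slices({
    'inputs': [[5, 7, PAD_ID, PAD_ID]],
    'targets': [[7, 9, PAD_ID, PAD_ID]],
})

# Give padded positions zero weight so they do not contribute to the loss,
# then prefetch with an autotuned buffer; with the take(1000) cap removed,
# the full split is consumed.
ds = ds.map(lambda x: {'inputs': x['inputs'],
                       'targets': x['targets'],
                       'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)})
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

for batch in ds:
  print(batch['weights'].numpy())  # -> [1. 1. 0. 0.]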

algoperf/workloads/lm/workload.py

Lines changed: 4 additions & 4 deletions
@@ -61,11 +61,11 @@ def num_train_examples(self) -> int:

   @property
   def num_eval_train_examples(self) -> int:
-    return 500  # Subset for evaluation.  # TODO(kasimbeg): update
+    return 10_000  # Subset for evaluation.

   @property
   def num_validation_examples(self) -> int:
-    return 500  # TODO(kasimbeg update)
+    return 100_000  # sequences

   @property
   def num_test_examples(self) -> int:
@@ -85,11 +85,11 @@ def train_stddev(self):

   @property
   def max_allowed_runtime_sec(self) -> int:
-    return 3600 * 5  # 4 hours TODO(kasimbeg): update
+    return 3600 * 14  # 14 hours TODO(kasimbeg): update

   @property
   def eval_period_time_sec(self) -> int:
-    return 600  # 10 minutes TODO(kasimbeg): update
+    return 1200  # 20 minutes TODO(kasimbeg): update

   @property
   def step_hint(self) -> int:
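
Taken together, these values enlarge the evaluation subsets and extend the time budget while evaluating less often. A condensed, illustrative sketch of how they relate is below; it is not the full workload class, which defines many more properties.

# Condensed illustration of the retuned values above.
class LmWorkloadBudget:

  @property
  def num_eval_train_examples(self) -> int:
    return 10_000  # training subset scored at each eval

  @property
  def num_validation_examples(self) -> int:
    return 100_000  # validation sequences

  @property
  def max_allowed_runtime_sec(self) -> int:
    return 3600 * 14  # 14-hour budget

  @property
  def eval_period_time_sec(self) -> int:
    return 1200  # evaluate roughly every 20 minutes


budget = LmWorkloadBudget()
# At most ~42 timed evals fit in the budget (14 h / 20 min).
print(budget.max_allowed_runtime_sec // budget.eval_period_time_sec)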

algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@ def init_optimizer_state(
170170
del rng
171171

172172
def jax_cosine_warmup(step_hint: int, hyperparameters):
173-
# Create learning rate schedule.
174173
step_hint = 0.75 * step_hint
174+
# Create learning rate schedule.
175175
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
176176
warmup_fn = optax.linear_schedule(
177177
init_value=0.0,
@@ -343,7 +343,7 @@ def update_params(
343343
)
344344

345345
# Log loss, grad_norm.
346-
if global_step % 1 == 0 and workload.metrics_logger is not None:
346+
if global_step % 100 == 0 and workload.metrics_logger is not None:
347347
workload.metrics_logger.append_scalar_metrics(
348348
{'loss': loss.item(), 'grad_norm': grad_norm.item()}, global_step
349349
)
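
The reordered comment sits inside the warmup-plus-cosine schedule builder. Below is a simplified sketch of that pattern with optax; the hyperparameters object and peak-learning-rate plumbing are simplified, and the numbers in the example call are made up for illustration (only warmup_factor=0.05 comes from the tuning point below).

import optax

def jax_cosine_warmup(step_hint: int, warmup_factor: float, peak_lr: float):
  # Shrink the step hint as in the diff, then split it into a linear warmup
  # followed by a cosine decay back toward zero.
  step_hint = 0.75 * step_hint
  warmup_steps = int(warmup_factor * step_hint)
  warmup_fn = optax.linear_schedule(
      init_value=0.0, end_value=peak_lr, transition_steps=warmup_steps)
  cosine_steps = max(int(step_hint) - warmup_steps, 1)
  cosine_fn = optax.cosine_decay_schedule(
      init_value=peak_lr, decay_steps=cosine_steps)
  return optax.join_schedules(
      schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])

schedule = jax_cosine_warmup(step_hint=48_000, warmup_factor=0.05, peak_lr=3.8e-4)
print(float(schedule(0)), float(schedule(1_800)))  # 0.0 at step 0, ~peak_lr after warmup

The second hunk only throttles scalar logging, reporting loss and grad_norm every 100 steps instead of every step.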
Lines changed: 6 additions & 6 deletions
@@ -1,11 +1,11 @@
 [
   {
     "dropout_rate": 0.0,
-    "label_smoothing": 0.1,
-    "learning_rate": 0.0003955553491092581,
-    "one_minus_beta1": 0.06124602712,
-    "beta2": 0.9535169492059872,
-    "weight_decay": 0.03268700808664715,
-    "warmup_factor": 0.0375
+    "label_smoothing": 0.0,
+    "learning_rate": 0.00038418421332238876,
+    "one_minus_beta1": 0.01564758865,
+    "beta2": 0.992362328914093,
+    "weight_decay": 0.25551270901641954,
+    "warmup_factor": 0.05
   }
 ]
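
As a usage sketch, the tuned point above can be parsed like any other JSON search space. The mapping of one_minus_beta1 to beta1 below follows the usual convention in the target-setting algorithms, and the inline JSON string simply mirrors the new values; the actual filename is not shown in this diff.

import json

# The new tuning point, inlined here for illustration.
search_space = json.loads("""
[
  {
    "dropout_rate": 0.0,
    "label_smoothing": 0.0,
    "learning_rate": 0.00038418421332238876,
    "one_minus_beta1": 0.01564758865,
    "beta2": 0.992362328914093,
    "weight_decay": 0.25551270901641954,
    "warmup_factor": 0.05
  }
]
""")

hp = search_space[0]
beta1 = 1.0 - hp['one_minus_beta1']  # one_minus_beta1 is consumed as 1 - beta1
print(f"beta1={beta1:.6f}  beta2={hp['beta2']:.6f}  lr={hp['learning_rate']:.2e}")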
