
Commit 6171b2d (1 parent: 65369f2)

update eval split sizes for lm workload and target setting point

File tree: 4 files changed, +14 −14 lines changed

algoperf/workloads/lm/input_pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -120,7 +120,7 @@ def get_lm_dataset(
     ds = ds.map(lambda x: {'inputs': x['inputs'],
                            'targets': x['targets'],
                            'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)})
-    ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE)  # todo(kasimbeg): set final size of validation
+    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
   elif split == 'validation':
     ds = batch_with_padding(
         sequences_ds,
@@ -133,6 +133,6 @@ def get_lm_dataset(
     ds = ds.map(lambda x: {'inputs': x['inputs'],
                            'targets': x['targets'],
                            'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)})
-    ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE)  # todo(kasimbeg): set final size
+    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

   return ds
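
For reference, a minimal sketch of the map-then-prefetch step that both eval branches now end with (PAD_ID and the dict keys come from the diff above; the function name and the assumption that the incoming dataset already yields batched 'inputs'/'targets' dicts are illustrative):

import tensorflow as tf

PAD_ID = 0  # assumed value for illustration; the real constant lives in input_pipeline.py

def finalize_eval_split(ds: tf.data.Dataset) -> tf.data.Dataset:
  # Zero out the loss weight of padded positions, keep everything else at 1.
  ds = ds.map(lambda x: {'inputs': x['inputs'],
                         'targets': x['targets'],
                         'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)})
  # No ds.take(1000) cap any more: the eval split size is governed by the
  # workload's num_*_examples properties, while tf.data tunes prefetch depth.
  return ds.prefetch(tf.data.experimental.AUTOTUNE)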

algoperf/workloads/lm/workload.py

Lines changed: 4 additions & 4 deletions
@@ -61,11 +61,11 @@ def num_train_examples(self) -> int:

   @property
   def num_eval_train_examples(self) -> int:
-    return 500  # Subset for evaluation. # TODO(kasimbeg): update
+    return 10_000  # Subset for evaluation.

   @property
   def num_validation_examples(self) -> int:
-    return 500  # TODO(kasimbeg): update
+    return 100_000  # sequences

   @property
   def num_test_examples(self) -> int:
@@ -85,11 +85,11 @@ def train_stddev(self):

   @property
   def max_allowed_runtime_sec(self) -> int:
-    return 3600 * 5  # 4 hours TODO(kasimbeg): update
+    return 3600 * 14  # 14 hours TODO(kasimbeg): update

   @property
   def eval_period_time_sec(self) -> int:
-    return 600  # 10 minutes TODO(kasimbeg): update
+    return 1200  # 20 minutes TODO(kasimbeg): update

   @property
   def step_hint(self) -> int:
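
Taken together, the new settings give a 3600 * 14 = 50,400-second run budget with an eval pass at most every 1,200 seconds, i.e. about 42 eval rounds over 10,000 eval-train and 100,000 validation sequences. A small sanity-check sketch (the stub class is illustrative; only the numbers are taken from the diff above):

class LmBudget:
  # Values copied from the workload.py diff above.
  num_eval_train_examples = 10_000
  num_validation_examples = 100_000    # sequences
  max_allowed_runtime_sec = 3600 * 14  # 14 hours
  eval_period_time_sec = 1200          # 20 minutes

budget = LmBudget()
print(budget.max_allowed_runtime_sec // budget.eval_period_time_sec)  # at most 42 eval rounds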

algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py

Lines changed: 2 additions & 2 deletions
@@ -170,8 +170,8 @@ def init_optimizer_state(
   del rng

   def jax_cosine_warmup(step_hint: int, hyperparameters):
-    # Create learning rate schedule.
     step_hint = 0.75 * step_hint
+    # Create learning rate schedule.
     warmup_steps = int(hyperparameters.warmup_factor * step_hint)
     warmup_fn = optax.linear_schedule(
         init_value=0.0,
@@ -343,7 +343,7 @@ def update_params(
   )

   # Log loss, grad_norm.
-  if global_step % 1 == 0 and workload.metrics_logger is not None:
+  if global_step % 100 == 0 and workload.metrics_logger is not None:
     workload.metrics_logger.append_scalar_metrics(
       {'loss': loss.item(), 'grad_norm': grad_norm.item()}, global_step
     )
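
The first hunk only shows the start of the learning rate schedule, so here is a self-contained sketch of a warmup-then-cosine schedule of the same shape (taking learning_rate and warmup_factor directly instead of a hyperparameters object is an assumption, and the real jax_cosine_warmup may handle the decay tail differently):

import optax

def cosine_warmup_sketch(step_hint: int, learning_rate: float, warmup_factor: float):
  # Target setting trains for 75% of the step hint; warm up for warmup_factor of that.
  step_hint = int(0.75 * step_hint)
  warmup_steps = int(warmup_factor * step_hint)
  warmup_fn = optax.linear_schedule(
      init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps)
  cosine_fn = optax.cosine_decay_schedule(
      init_value=learning_rate, decay_steps=step_hint - warmup_steps)
  return optax.join_schedules(
      schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])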
Lines changed: 6 additions & 6 deletions
@@ -1,11 +1,11 @@
 [
   {
     "dropout_rate": 0.0,
-    "label_smoothing": 0.1,
-    "learning_rate": 0.0003955553491092581,
-    "one_minus_beta1": 0.06124602712,
-    "beta2": 0.9535169492059872,
-    "weight_decay": 0.03268700808664715,
-    "warmup_factor": 0.0375
+    "label_smoothing": 0.0,
+    "learning_rate": 0.00038418421332238876,
+    "one_minus_beta1": 0.01564758865,
+    "beta2": 0.992362328914093,
+    "weight_decay": 0.25551270901641954,
+    "warmup_factor": 0.05
   }
 ]
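
The point stores one_minus_beta1 rather than beta1 itself, commonly done so values close to 1 can be sampled on a log scale. A minimal sketch of unpacking the point when setting up the optimizer (only the conversion is shown; the surrounding NAdamW setup is omitted):

import json

point = json.loads('''
[{"dropout_rate": 0.0,
  "label_smoothing": 0.0,
  "learning_rate": 0.00038418421332238876,
  "one_minus_beta1": 0.01564758865,
  "beta2": 0.992362328914093,
  "weight_decay": 0.25551270901641954,
  "warmup_factor": 0.05}]
''')[0]

beta1 = 1.0 - point['one_minus_beta1']  # ~0.9844
print(beta1, point['beta2'], point['weight_decay'])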
