Skip to content

Commit 0618974

Browse files
Merge pull request #682 from mlcommons/dev
dev -> main
2 parents 6b188ba + 68a0d18 commit 0618974

File tree

7 files changed

+39
-17
lines changed

7 files changed

+39
-17
lines changed

DOCUMENTATION.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -400,12 +400,12 @@ The currently eight fixed workloads are:
400400

401401
| | **Task** | **Dataset** | **Model** | **Loss** | **Metric** | Validation<br>**Target** | Test<br>**Target** | Maximum<br>**Runtime** <br>(in secs) |
402402
|------------|-------------------------------|-------------|-------------------------|----------|------------|--------------------------|----------------------|------------------------|
403-
| **1** | Clickthrough rate prediction | Criteo 1TB | DLRMsmall | CE | CE | 0.123649 | 0.126060 | 21,600 |
404-
| **2** | MRI reconstruction | fastMRI | U-Net | L1 | SSIM | 0.7344 | 0.741652 | 10,800 |
405-
| **3<br>4** | Image classification | ImageNet | ResNet-50<br>ViT | CE | ER | 0.22569<br>0.22691 | 0.3440<br>0.3481 | 111,600 <br> 111,600 |
406-
| **5<br>6** | Speech recognition | LibriSpeech | Conformer<br>DeepSpeech | CTC | WER | 0.078477<br>0.1162 | 0.046973<br>0.068093 | <br>72,000 |
407-
| **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 12,000 |
408-
| **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 80,000 |
403+
| **1** | Clickthrough rate prediction | Criteo 1TB | DLRMsmall | CE | CE | 0.123735 | 0.126041 | 7,703 |
404+
| **2** | MRI reconstruction | fastMRI | U-Net | L1 | SSIM | 0.723653 | 0.740633 | 8,859 |
405+
| **3<br>4** | Image classification | ImageNet | ResNet-50<br>ViT | CE | ER | 0.22569<br>0.22691 | 0.3440<br>0.3481 | 63,008 <br> 77,520 |
406+
| **5<br>6** | Speech recognition | LibriSpeech | Conformer<br>DeepSpeech | CTC | WER | 0.085884<br>0.119936 | 0.052981<br>0.074143 | 61,068<br>55,506 |
407+
| **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 18,477 |
408+
| **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 48,151 |
409409

410410
#### Randomized workloads
411411

algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
2424
def initialized(self, key: spec.RandomState,
2525
model: nn.Module) -> spec.ModelInitState:
2626
input_shape = (1, 224, 224, 3)
27-
variables = jax.jit(model.init)({'params': key}, jnp.ones(input_shape))
27+
params_rng, dropout_rng = jax.random.split(key)
28+
variables = jax.jit(
29+
model.init)({'params': params_rng, 'dropout': dropout_rng},
30+
jnp.ones(input_shape))
2831
model_state, params = variables.pop('params')
2932
return params, model_state
3033

algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,10 +234,11 @@ def init_model_fn(
234234
self._train_model = models.Transformer(model_config)
235235
eval_config = replace(model_config, deterministic=True)
236236
self._eval_model = models.Transformer(eval_config)
237-
initial_variables = jax.jit(self._eval_model.init)(
238-
rng,
239-
jnp.ones(input_shape, jnp.float32),
240-
jnp.ones(target_shape, jnp.float32))
237+
params_rng, dropout_rng = jax.random.split(rng)
238+
initial_variables = jax.jit(
239+
self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng},
240+
jnp.ones(input_shape, jnp.float32),
241+
jnp.ones(target_shape, jnp.float32))
241242

242243
initial_params = initial_variables['params']
243244
self._param_shapes = param_utils.jax_param_shapes(initial_params)

datasets/dataset_setup.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,22 @@ def download_criteo1tb(data_dir,
291291
stream=True)
292292

293293
all_days_zip_filepath = os.path.join(tmp_criteo_dir, 'all_days.zip')
294-
with open(all_days_zip_filepath, 'wb') as f:
295-
for chunk in download_request.iter_content(chunk_size=1024):
296-
f.write(chunk)
294+
download = True
295+
if os.path.exists(all_days_zip_filepath):
296+
while True:
297+
overwrite = input('File already exists {}.\n Overwrite? (Y/n)'.format(
298+
all_days_zip_filepath)).lower()
299+
if overwrite in ['y', 'n']:
300+
break
301+
logging.info('Invalid response. Try again.')
302+
if overwrite == 'n':
303+
logging.info(f'Skipping download to {all_days_zip_filepath}')
304+
download = False
305+
306+
if download:
307+
with open(all_days_zip_filepath, 'wb') as f:
308+
for chunk in download_request.iter_content(chunk_size=1024):
309+
f.write(chunk)
297310

298311
unzip_cmd = f'unzip {all_days_zip_filepath} -d {tmp_criteo_dir}'
299312
logging.info(f'Running Criteo 1TB unzip command:\n{unzip_cmd}')
@@ -679,6 +692,7 @@ def main(_):
679692
if any(s in tmp_dir for s in bad_chars):
680693
raise ValueError(f'Invalid temp_dir: {tmp_dir}.')
681694
data_dir = os.path.abspath(os.path.expanduser(data_dir))
695+
tmp_dir = os.path.abspath(os.path.expanduser(tmp_dir))
682696
logging.info('Downloading data to %s...', data_dir)
683697

684698
if FLAGS.all or FLAGS.criteo1tb:

docker/scripts/startup.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_
157157
"criteo1tb_resnet" "criteo1tb_layernorm" "criteo1tb_embed_init" \
158158
"wmt" "wmt_post_ln" "wmt_attention_temp" "wmt_glu_tanh" \
159159
"librispeech_deepspeech" "librispeech_conformer" "mnist" \
160-
"conformer_layernorm" "conformer_attention_temperature" \
161-
"conformer_gelu" "fastmri_model_size" "fastmri_tanh" \
160+
"librispeech_conformer_layernorm" "librispeech_conformer_attention_temperature" \
161+
"librispeech_conformer_gelu" "fastmri_model_size" "fastmri_tanh" \
162162
"librispeech_deepspeech_tanh" \
163163
"librispeech_deepspeech_no_resnet" "librispeech_deepspeech_norm_and_spec_aug"
164164
"fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size")

reference_algorithms/target_setting_algorithms/get_batch_size.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ def get_batch_size(workload_name):
1515
return 512
1616
elif workload_name == 'imagenet_vit':
1717
return 1024
18+
elif workload_name == 'imagenet_vit_glu':
19+
return 512
1820
elif workload_name == 'librispeech_conformer':
1921
return 256
2022
elif workload_name == 'librispeech_deepspeech':

submission_runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,9 @@ def train_once(
387387
train_state['test_goal_reached'] = (
388388
workload.has_reached_test_target(latest_eval_result) or
389389
train_state['test_goal_reached'])
390-
390+
goals_reached = (
391+
train_state['validation_goal_reached'] and
392+
train_state['test_goal_reached'])
391393
# Save last eval time.
392394
eval_end_time = get_time()
393395
train_state['last_eval_time'] = eval_end_time

0 commit comments

Comments
 (0)