From bc2edf706b4cbcb9f62dfb0e5d0f4657161282c3 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 24 Apr 2024 13:54:50 -0700 Subject: [PATCH] wip --- config/gpt2_small_fast_public.yaml | 33 ++++++++++++++++++++++++++ scripts/clean_old_checkpoints.py | 37 ++++++++++++++++++++++++------ 2 files changed, 63 insertions(+), 7 deletions(-) create mode 100644 config/gpt2_small_fast_public.yaml diff --git a/config/gpt2_small_fast_public.yaml b/config/gpt2_small_fast_public.yaml new file mode 100644 index 000000000..1b466ef91 --- /dev/null +++ b/config/gpt2_small_fast_public.yaml @@ -0,0 +1,33 @@ +data: + train_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" + validation_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" + cache_dir: "gs://levanter-data/tokenized/openwebtext/" + tokenizer: "gpt2" +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + tracker: + - type: wandb + project: "levanter" + tags: [ "openwebtext", "gpt2", "itest"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: -1 + + train_batch_size: 256 + num_train_steps: 10000 + ray: + auto_start_cluster: false +optimizer: + learning_rate: 1E-3 + weight_decay: 0.1 + warmup: 0.01 diff --git a/scripts/clean_old_checkpoints.py b/scripts/clean_old_checkpoints.py index 0ebe38b9b..dbd68eb27 100644 --- a/scripts/clean_old_checkpoints.py +++ b/scripts/clean_old_checkpoints.py @@ -4,9 +4,13 @@ # python clean_old_checkpoints.py gs://my-bucket/my-dir | xargs -I {} gsutil -m rm -r {} import os import sys +import time +from datetime import datetime, timezone import fsspec +AGE = 30 # days + def is_dir_of_checkpoints(path): fs = fsspec.filesystem("gcs") @@ -45,16 +49,35 @@ def list_deletable_directories(base_dir): # Add all checkpoint directories except the ones we need to keep for path in checkpoint_paths: 
- if path != max_complete_checkpoint and path != max_000_checkpoint: - yield path + if path == max_complete_checkpoint or path == max_000_checkpoint: + continue + + try: + new = False + for file in ["metadata.json", "worker-0.cert"]: + details = fs.ls(f"{path}/{file}", detail=True) + if details: + mtime = details[0]["mtime"] + age = (datetime.now(timezone.utc) - mtime).days + if age < AGE: + new = True + break + + if new: + continue + + except FileNotFoundError: + pass + + yield path + # Usage example: if __name__ == "__main__": - if len(sys.argv) != 2: + if len(sys.argv) < 2: print("Usage: python clean_old_checkpoints.py <base_dir>") sys.exit(1) - base_dir = sys.argv[1] - - for path in list_deletable_directories(base_dir): - print(f"gs://{path}") + for base_dir in sys.argv[1:]: + for path in list_deletable_directories(base_dir): + print(f"gs://{path}")