diff --git a/train.py b/train.py
index 6994fb9b..f647dbc9 100644
--- a/train.py
+++ b/train.py
@@ -447,7 +447,7 @@ def step(self):
 
 # Model size
 DEPTH = 8 # number of transformer layers
-DEVICE_BATCH_SIZE = 128 # per-device batch size (reduce if OOM)
+DEVICE_BATCH_SIZE = int(os.environ.get("DEVICE_BATCH_SIZE", 128)) # per-device batch size (reduce if OOM)
 
 # ---------------------------------------------------------------------------
 # Setup: tokenizer, model, optimizer, dataloader
@@ -492,7 +492,8 @@ def build_model_config(depth):
 print(f"Estimated FLOPs per token: {num_flops_per_token:e}")
 
 tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN
-assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0
+assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0, \
+    f"DEVICE_BATCH_SIZE={DEVICE_BATCH_SIZE} does not evenly divide TOTAL_BATCH_SIZE={TOTAL_BATCH_SIZE} (tokens_per_fwdbwd={tokens_per_fwdbwd})"
 grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd
 
 optimizer = model.setup_optimizer(