From 5f1a1d64b0b668a4b139f0133b4db0fe1e8223b2 Mon Sep 17 00:00:00 2001 From: Michalis Kargakis Date: Tue, 10 Mar 2026 15:19:22 +0100 Subject: [PATCH] feat: make DEVICE_BATCH_SIZE configurable via env var Allow overriding DEVICE_BATCH_SIZE with an environment variable so users with smaller GPUs can reduce it without editing source code, e.g.: DEVICE_BATCH_SIZE=16 uv run train.py grad_accum_steps auto-adjusts to preserve TOTAL_BATCH_SIZE. Co-Authored-By: Claude Opus 4.6 --- train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 6994fb9b..f647dbc9 100644 --- a/train.py +++ b/train.py @@ -447,7 +447,7 @@ def step(self): # Model size DEPTH = 8 # number of transformer layers -DEVICE_BATCH_SIZE = 128 # per-device batch size (reduce if OOM) +DEVICE_BATCH_SIZE = int(os.environ.get("DEVICE_BATCH_SIZE", 128)) # per-device batch size (reduce if OOM) # --------------------------------------------------------------------------- # Setup: tokenizer, model, optimizer, dataloader @@ -492,7 +492,8 @@ def build_model_config(depth): print(f"Estimated FLOPs per token: {num_flops_per_token:e}") tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN -assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0 +assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0, \ + f"DEVICE_BATCH_SIZE={DEVICE_BATCH_SIZE} does not evenly divide TOTAL_BATCH_SIZE={TOTAL_BATCH_SIZE} (tokens_per_fwdbwd={tokens_per_fwdbwd})" grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd optimizer = model.setup_optimizer(