Experimental playground for benchmarking language model (LM) architectures, layers, and tricks on smaller datasets. Designed for flexible experimentation and exploration.
See the techniques/ directory for technical explanations and code snippets for various techniques and models implemented in the project.
BlaGPT is a flexible Transformer implementation that lets you turn the following things on and off in the config.

I basically do a greedy architecture search: I add each new technique on top of the best model config so far and check whether it improves performance. I know this is not the best way, but it is fun and helps build at least some intuition about what works and what doesn't. A minimal sketch of such a config is shown after the list below.
- Multi-token prediction - link
- Weight tying - link
- Grouped query attention - link
- Capping logits - link
- QKV bias - link
- Zero-init projection layer - link
- Post and pre-RMSNorm - link
- Setting base theta to 1_000_000 - llama3 - increased the final validation loss - best 3.3324
- Z-loss regularization - link - increased the final validation loss by 0.02 - loss: 3.3527
- KV-Shifting attention - link - seems to improve performance - loss: 3.3310 -> 3.3138 - peak memory consumption: 42858 MiB
- Dilated Attention (LongNet) - link
- Multi-Head Latent Attention - link - loss: 3.3479 - peak memory consumption: 42192 MiB
- Per token output bias - link - loss: 3.3257 - peak memory consumption: 42120 MiB
- DyT Norm - link - didn't really work. Loss stuck too high
- Forgetting Transformer (vanilla and Pro versions) - link - vanilla loss: 3.3243, Pro loss: OOM
- Multi-Token Attention - link - loss: 3.3357 - peak memory: 42136 MiB
- Differential Attention - link - best_model_loss: 3.2411 -> loss: 3.2460 - peak memory: 41521 MiB
- Softpick - link - loss: 3.3446 - peak memory: 59417 MiB
- Canon Layer - link - loss: 3.3217 - peak memory: 43199 MiB
- Parallel Transformer Block - link - loss: 3.3473 - peak memory: 40302 MiB
- Per Layer Token Embedding - link - loss: 3.2411 - peak memory: 40916 MiB
- PolyNorm - link - best_model_loss: 3.2411 -> loss: 3.3017 - peak memory: 40895 MiB
- PolyReLU - link - best_model_loss: 3.2411 -> loss: 3.2642 - peak memory: 40890 MiB
- TOP loss - paper | explanation - best_model_loss: 3.2411 -> loss: 3.2636 - peak memory: 47816 MiB
- Simplified RoPE - link - best_model_loss: 3.2411 -> loss: 3.2620 - peak memory: 43585 MiB - step_avg: 388.54ms
- Gated Attention - paper | explanation - best_model_loss: 3.2411 -> new_best_model_loss: 3.2327 - peak memory: 45968 MiB - step_avg: 413.01ms
- ResFormer (Plus) - paper | explanation - best_model_loss: 3.2327 -> model_loss: 3.3538 - peak memory: 38223 MiB - step_avg: 326.32ms
- Engram (my simple variant) - paper | explanation - best_model_loss: 3.2327 -> new_best_model_loss: 3.2296 - peak memory: 50488 MiB - step_avg: 504.09ms
- Differential Attention v2 - paper | explanation - best_model_loss: 3.2296 -> new_best_model_loss: 3.2274 - peak memory: 52829 MiB - step_avg: 535.16ms
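To make the toggle idea concrete, here is a minimal sketch of what such a config could look like. The field names are illustrative, not the actual flags in the BlaGPT config class; check the config in the repo for the real ones.

```python
# Illustrative only: each technique sits behind a toggle, so the greedy search
# can flip one switch at a time on top of the current best config.
# Field names are hypothetical - see the actual BlaGPT config class.
from dataclasses import dataclass, replace
from typing import Optional


@dataclass
class BlaGPTConfig:
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    tie_weights: bool = True               # weight tying
    n_kv_head: int = 4                     # grouped query attention (n_kv_head < n_head)
    qkv_bias: bool = False                 # QKV bias
    logit_softcap: Optional[float] = 30.0  # capping logits
    zero_init_proj: bool = True            # zero-init projection layer
    pre_and_post_rmsnorm: bool = True      # post- and pre-RMSNorm
    rope_theta: float = 10_000.0           # raising to 1_000_000 hurt validation loss
    kv_shifting_attention: bool = False    # KV-shifting attention
    multi_token_prediction: bool = False


best = BlaGPTConfig()
candidate = replace(best, kv_shifting_attention=True)  # try one change at a time
```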
- MegaByte - link - loss: 3.810
- FTP (heavily modified) - link - loss: 3.901
- Rene - link - loss: 3.340
- Rwkv7 - link - loss: 4.450
- Zamba2 - link - Zamba2 > Rene > Rwkv7
- Hourglass Transformer (modified) - link - Hourglass > MegaByte > FTP - loss: 3.710
- Hymba - link - train step time is significantly slower than the Transformers. Best validation loss so far: 4.7505
- Tokenformer (in BlaGPT model) - link - loss: 3.390
- LLaDa (dLLM) - link - val-loss: 8.6930, xentropy-loss: 4.2891 (comparable to other models; estimated by llada_validation_cross_entropy.py)
- Avey - link - loss: 3.323, peak memory: 51962 MiB (batch size 8), step_time: 2871ms (very slow to train and uses >3x more memory than other models)
- LFM2 - link - TBD
- Kimi Delta Attention (1:3 interleaved Full Attention) - link - best_model_loss: 3.2411 -> loss: 3.2532, peak_memory: 47391, step_time: 568.1ms
- Hourglass Transformer (modified) - link - val_loss: 1.0048, train_time: 2671049ms, step_avg: 524.76ms
- AUNet - link - val_loss: 1.1502, train_time: 7246104ms, step_avg: 1423.60ms
- SpaceByte - link - val_loss: 1.6755, train_time: 2154923ms, step_avg: 423.36ms, peak memory consumption: 27781 MiB
- HNet - link - val_loss: 1.4554, train_time: 2207809ms, step_avg: 433.75ms, peak memory consumption: 23948 MiB
- PaLMForeachSOAP - link - almost 2x slower than Adam but gives the best results
- Ademamix - link - unstable even after trying different learning rates
- Adopt - link - straight-up NaN
- CAdamW - link - loss: 3.3517
- AdamW with independent weight decay - link - loss: 3.320
- Adam - loss: 3.3224
- AdamW - loss: 3.3310, peak VRAM: 42053 MiB, step_time: 533ms
- DeMo - link - saves 7 GB per GPU, but loss is higher than the baseline and step time is slower than Adam - loss: 3.4676, peak VRAM: 41534 MiB, step_time: 820ms
- Adam-Mini - link - loss is higher than Adam and AdamW and it is also slower (?); saved a bit of VRAM - loss: 3.3324, peak VRAM: 41534 MiB, step_time: 610ms
- MARS - link - loss: 3.3459, peak VRAM: 40953 MiB, step_time: 628ms
- Muon - link - loss: 3.2923, peak VRAM: 40332 MB, step_time: 620.24ms (see the sketch after this list)
- AdaMuon - paper | explanation - Adaptive Muon with second-moment estimation (default optimizer)
- BiClip - link - (not working well) loss: 7.2292, peak VRAM: 39751 MiB, step_time: 510ms
- NorMuon - paper | explanation - best_model_loss: 3.2411 -> loss: 3.4630, peak VRAM: 44154 MiB, step_time: 387.46ms
- Cautious Weight Decay - paper | explanation - best_model_loss: 3.2327 -> loss: 3.2334, peak VRAM: 45971 MiB, step_time: 434.2ms
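Since Muon/AdaMuon are called out above (AdaMuon is the default optimizer), here is a minimal sketch of the core Muon idea: momentum SGD whose updates for 2D weight matrices are approximately orthogonalized with a Newton-Schulz iteration. This is a simplified illustration of the technique, not the implementation used in this repo (which also handles non-2D parameters, scaling, and the adaptive second moment in AdaMuon).

```python
# Simplified Muon-style optimizer (illustration only, not the repo's implementation).
import torch


def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Approximately orthogonalize a 2D matrix with a quintic Newton-Schulz iteration."""
    a, b, c = 3.4445, -4.7750, 2.0315  # quintic iteration coefficients used by Muon
    X = G / (G.norm() + eps)
    transposed = X.size(0) > X.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if transposed else X


class SimpleMuon(torch.optim.Optimizer):
    def __init__(self, params, lr: float = 0.02, momentum: float = 0.95):
        super().__init__(params, dict(lr=lr, momentum=momentum))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None or p.ndim != 2:
                    continue  # Muon targets 2D weight matrices; others need a separate optimizer
                buf = self.state[p].setdefault("momentum", torch.zeros_like(p.grad))
                buf.mul_(group["momentum"]).add_(p.grad)
                update = newton_schulz_orthogonalize(buf)
                p.add_(update, alpha=-group["lr"])
```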
- Implement the model
- Return the loss in the forward function
- Add the model to `model_registry.py`
- Start training

See one of the existing implementations for details; a minimal sketch is shown below.
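The sketch below is illustrative, not a drop-in file: it shows a small model whose forward pass returns the loss when targets are provided. The class name and exact return signature are assumptions; check an existing model and `model_registry.py` for the real interface and registration format.

```python
# Minimal sketch of a new model for the training loop (illustrative only).
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class TinyLM(nn.Module):
    def __init__(self, vocab_size: int = 50304, dim: int = 512, n_layers: int = 4, n_heads: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=n_heads, dim_feedforward=4 * dim,
            batch_first=True, norm_first=True,
        )
        self.blocks = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        T = idx.size(1)
        causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=idx.device), diagonal=1)
        x = self.blocks(self.embed(idx), mask=causal_mask)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # return the loss from forward, as expected by the training scripts
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

# Finally, add an entry for the new model (and its config) to model_registry.py;
# see the existing entries for the exact format the training scripts expect.
```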
BlaGPT provides two training scripts:

| Script | Purpose | Data Sources |
|---|---|---|
| `train.py` | Quick experimentation with pre-tokenized data | Binary shards only |
| `train_flex.py` | Flexible training with custom datasets/tokenizers | Binary shards, HF streaming, byte-level |
For rapid architecture benchmarking with the default FineWeb10B dataset:

```bash
# Get the pre-tokenized data
python data/fineweb10B_cached.py

# Start training
torchrun --standalone --nproc_per_node=8 bla_gpt/train.py --run_name my_experiment --model_name blagpt
```

For custom datasets, different tokenizers, or HuggingFace streaming:

```bash
torchrun --standalone --nproc_per_node=8 bla_gpt/train_flex.py \
  --model_name blagpt \
  --run_name my_experiment
```

Stream data directly from the HuggingFace Hub - no pre-download required:
```bash
# FineWeb with GPT-2 tokenizer (tiktoken)
torchrun --standalone --nproc_per_node=8 bla_gpt/train_flex.py \
  --model_name blagpt \
  --run_name hf_streaming_test \
  --use_hf_streaming \
  --hf_dataset "HuggingFaceFW/fineweb" \
  --hf_dataset_config "sample-10BT"

# C4 dataset with HuggingFace tokenizer
torchrun --standalone --nproc_per_node=8 bla_gpt/train_flex.py \
  --model_name blagpt \
  --run_name c4_experiment \
  --use_hf_streaming \
  --hf_dataset "allenai/c4" \
  --hf_dataset_config "en" \
  --tokenizer_backend huggingface \
  --tokenizer_name "gpt2"

# Custom dataset with custom text column
torchrun --standalone --nproc_per_node=8 bla_gpt/train_flex.py \
  --model_name blagpt \
  --run_name custom_data \
  --use_hf_streaming \
  --hf_dataset "username/my-dataset" \
  --hf_text_column "content"
```

| Option | Default | Description |
|---|---|---|
| `--use_hf_streaming` | `False` | Enable HuggingFace streaming mode |
| `--hf_dataset` | `HuggingFaceFW/fineweb` | HuggingFace dataset name or path |
| `--hf_dataset_config` | `sample-10BT` | Dataset configuration/subset |
| `--hf_split` | `train` | Training split name |
| `--hf_text_column` | `text` | Column containing text data |
| `--hf_val_dataset` | Same as train | Separate validation dataset (optional) |
| `--hf_val_dataset_config` | Same as train | Validation dataset config |
| `--hf_val_split` | `train` | Validation split name |
| `--hf_val_samples` | `1000` | Number of samples for validation (first N of the validation split) |
| `--hf_shuffle_buffer` | `10000` | Shuffle buffer size for streaming |
Train/Validation Split: By default, the first 1000 samples are used for validation and training skips these samples (no overlap). When using a separate validation dataset or split, this behavior is automatically disabled.
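For intuition, this is roughly what the default behavior amounts to with a streamed HuggingFace dataset. It is a sketch of the described split, not the actual loader code in `train_flex.py`.

```python
# Sketch: the first N samples become the validation set and training skips them,
# so there is no overlap. The real loader may differ in details.
from datasets import load_dataset

val_samples = 1000  # --hf_val_samples

stream = load_dataset("HuggingFaceFW/fineweb", "sample-10BT", split="train", streaming=True)

val_stream = stream.take(val_samples)       # first N samples -> validation
train_stream = (
    stream.skip(val_samples)                # training skips the validation samples
          .shuffle(seed=0, buffer_size=10_000)  # --hf_shuffle_buffer
)
```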
| Option | Default | Description |
|---|---|---|
| `--tokenizer_backend` | `tiktoken` | Backend: `tiktoken` or `huggingface` |
| `--tokenizer_name` | `gpt2` | Tokenizer name (encoding or model path) |

tiktoken encodings: `gpt2`, `cl100k_base`, `o200k_base`

HuggingFace tokenizers: any model name, e.g., `meta-llama/Llama-2-7b-hf`, `mistralai/Mistral-7B-v0.1`
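Roughly, the two backends boil down to the following (an illustration; the actual wiring inside `train_flex.py` may differ):

```python
# Sketch of the two tokenizer backends.
import tiktoken
from transformers import AutoTokenizer

text = "Hello world"

# --tokenizer_backend tiktoken --tokenizer_name gpt2
enc = tiktoken.get_encoding("gpt2")
ids_tiktoken = enc.encode(text)

# --tokenizer_backend huggingface --tokenizer_name gpt2 (or any HF model name)
tok = AutoTokenizer.from_pretrained("gpt2")
ids_hf = tok.encode(text)
```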
Run before training to find an optimal learning rate:

```bash
torchrun --standalone --nproc_per_node=8 bla_gpt/find_lr.py --model_name blagpt

# Output
Results:
Steepest gradient learning rate: 3.31e-06
Elbow point learning rate: 1.20e-01
Plot saved to: logs/lr_finder_blagpt/lr_finder_plot.png
Results saved to: logs/lr_finder_blagpt/lr_finder_results.pt
```
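For intuition, the "steepest gradient" suggestion can be read off a learning-rate sweep roughly as below. The file name and analysis are illustrative assumptions, not the actual logic in `find_lr.py`.

```python
# Rough illustration: the steepest-gradient LR is where the smoothed loss
# drops fastest over an exponentially increasing learning rate.
import numpy as np

lrs = np.logspace(-7, 0, 200)             # hypothetical swept learning rates
losses = np.load("lr_sweep_losses.npy")   # hypothetical recorded loss per LR

smooth = np.convolve(losses, np.ones(5) / 5, mode="same")  # light smoothing
slope = np.gradient(smooth, np.log10(lrs))

steepest_lr = lrs[np.argmin(slope)]       # most negative slope = fastest loss decrease
print(f"Steepest gradient learning rate: {steepest_lr:.2e}")
```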
- Check `best_model_config.py` for the best model configuration so far.
- You can run the training with the best model config by running:

```bash
torchrun --standalone --nproc_per_node=8 train.py --run_name best_model --model_name best
```

The initial code is based on:
- Nano GPT - link
- Modded NanoGPT - link

Thanks to @xumingyu2021 for the memory-friendly implementation of Differential Attention.