Experimental playground for benchmarking language model (LM) architectures, layers, and tricks on smaller datasets. Designed for flexible experimentation and exploration.


erogol/BlaGPT


BlaGPT


📚 Technique Documentation

See the techniques/ directory for technical explanations and code snippets for various techniques and models implemented in the project.

Techniques under BlaGPT

BlaGPT is a flexible Transformer implementation in which each of the following techniques can be turned on or off in the config.

I basically run a greedy architecture search: I add each new technique on top of the best model config and check whether it improves performance. I know this is not the best approach, but it is fun and helps build at least some intuition about what works and what doesn't.
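The greedy procedure described above can be sketched in a few lines. This is a minimal illustration, not code from the repository. `greedy_search`, the dict-of-flags config, and the `evaluate` callable are all hypothetical. `evaluate` stands in for a full training run that returns the final validation loss.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical config: technique name -> enabled flag.
Config = Dict[str, bool]


def greedy_search(
    base_config: Config,
    techniques: List[str],
    evaluate: Callable[[Config], float],
) -> Tuple[Config, float]:
    """Stack techniques one at a time on top of the best config so far.

    `evaluate` stands in for a full training run and returns the final
    validation loss (lower is better).
    """
    best_config = dict(base_config)
    best_loss = evaluate(best_config)
    for name in techniques:
        # Turn the new technique on, on top of the current best config.
        candidate = {**best_config, name: True}
        loss = evaluate(candidate)
        # Keep it only if it lowers the validation loss.
        if loss < best_loss:
            best_config, best_loss = candidate, loss
    return best_config, best_loss
```

Each technique is tested only once, against whatever the best config was at that moment. The outcome therefore depends on the order in which techniques are tried, which is one reason this is not the best search strategy.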

Multi-token prediction - link

Weight tying - link

Grouped query attention - link

Capping logits - link

QKV bias - link

Zero-init projection layer - link

Post and pre-RMSNorm - link

Setting the RoPE base theta to 1_000_000 (llama3) - increased the final validation loss - best: 3.3324

Z-loss regularization - link - increased the final validation loss by 0.02 - loss: 3.3527

KV-Shifting attention - link - seems to improve performance - loss: 3.3310 -> 3.3138 - peak memory consumption: 42858 MiB

Dilated Attention (LongNet) - link

Multi-Head Latent Attention - link - loss: 3.3479 - peak memory consumption: 42192 MiB

Per token output bias - link - loss: 3.3257 - peak memory consumption: 42120 MiB

DyT Norm - link - didn't work well; the loss stayed too high

Forgetting Transformer (vanilla and Pro versions) - link - vanilla loss: 3.3243, Pro loss: OOM

Multi-Token Attention - link - loss: 3.3357 - peak memory: 42136 MiB

Differential Attention - link - best_model_loss: 3.2411 -> loss: 3.2460 - peak memory: 41521 MiB

Softpick - link - loss: 3.3446 - peak memory: 59417 MiB

Canon Layer - link - loss: 3.3217 - peak memory: 43199 MiB

Parallel Transformer Block - link - loss: 3.3473 - peak memory: 40302 MiB

Per Layer Token Embedding - link - loss: 3.2411 - peak memory: 40916 MiB

PolyNorm - link - best_model_loss: 3.2411 -> loss: 3.3017 - peak memory: 40895 MiB

PolyReLU - link - best_model_loss: 3.2411 -> loss: 3.2642 - peak memory: 40890 MiB

TOP loss - paper | explanation - best_model_loss: 3.2411 -> loss: 3.2636 - peak memory: 47816 MiB

Simplified RoPE - link - best_model_loss: 3.2411 -> loss: 3.2620 - peak memory: 43585 MiB - step_avg: 388.54ms

Gated Attention - paper | explanation - best_model_loss: 3.2411 -> new_best_model_loss: 3.2327 - peak memory: 45968 MiB - step_avg: 413.01ms

ResFormer (Plus) - paper | explanation - best_model_loss: 3.2327 -> model_loss: 3.3538 - peak memory: 38223 MiB - step_avg: 326.32ms

Engram (My Simple Variant) - paper | explanation - best_model_loss: 3.2327 -> new_best_model_loss: 3.2296 - peak memory: 50488 MiB - step_avg: 504.09ms

👑 Differential Attention v2 - paper | explanation - best_model_loss: 3.2296 -> new_best_model_loss: 3.2274 - peak memory: 52829 MiB - step_avg: 535.16ms
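Two of the logit-level techniques listed above, logit capping and z-loss, fit in a few lines. This sketch assumes the common formulations: tanh soft-capping `cap * tanh(z / cap)`, as used in Gemma 2, and the PaLM-style auxiliary term `coef * logsumexp(z) ** 2`. The cap of 30.0 and the coefficient of 1e-4 are illustrative defaults and are not necessarily BlaGPT's settings. The function names are hypothetical.

```python
import math
from typing import List


def soft_cap(logits: List[float], cap: float = 30.0) -> List[float]:
    """Smoothly squash logits into [-cap, cap] with tanh (applied before softmax)."""
    return [cap * math.tanh(z / cap) for z in logits]


def logsumexp(logits: List[float]) -> float:
    """Numerically stable log(sum(exp(z)))."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits))


def cross_entropy_with_z_loss(
    logits: List[float], target: int, z_coef: float = 1e-4
) -> float:
    """Cross-entropy for one position plus a z-loss penalty.

    The z-loss term pushes log Z (the softmax normalizer) toward 0,
    which discourages the logits from drifting to large magnitudes.
    """
    log_z = logsumexp(logits)
    ce = log_z - logits[target]  # equals -log softmax(logits)[target]
    return ce + z_coef * log_z ** 2
```

Soft-capping bounds the logits themselves, while z-loss only adds a penalty to the objective, so the two can be combined.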

Other Models

MegaByte - link - loss: 3.810

FTP (heavily modified) - link - loss: 3.901

Rene - link - loss: 3.340

Rwkv7 - link - loss: 4.450

Zamba2 - link - Zamba2 > Rene > Rwkv7

Hourglass Transformer (modified) - link - Hourglass > MegaByte > FTP - loss: 3.710

Hymba - link - training step time is significantly slower than for the Transformer models. Best validation loss so far: 4.7505

Tokenformer (in BlaGPT model) - link - loss: 3.390

LLaDA (dLLM) - link - val-loss: 8.6930, xentropy-loss: 4.2891 (estimated by llada_validation_cross_entropy.py so that it is comparable to the other models' losses)

Avey - link - loss: 3.323, peak memory: 51962 MiB (batch size 8), step_time: 2871ms (very slow to train and uses >3x more memory than other models)

LFM2 - link - TBD

Kimi Delta Attention (1:3 interleaved Full Attention) - link - best_model_loss: 3.2411 -> loss: 3.2532, peak memory: 47391 MiB, step_time: 568.1ms
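The "1:3 interleaved" setup for Kimi Delta Attention means one full-attention layer for every three KDA layers. The sketch below shows one such layer schedule. Placing the full-attention layer last in each group of four is an assumption, and `layer_schedule` is a hypothetical helper, not a function from the repository.

```python
from typing import List


def layer_schedule(n_layers: int, linear_per_full: int = 3) -> List[str]:
    """Return the attention type of each layer in a hybrid stack.

    Every group of (linear_per_full + 1) layers contains `linear_per_full`
    KDA (linear attention) layers followed by one full-attention layer.
    """
    period = linear_per_full + 1
    return [
        "full" if (i + 1) % period == 0 else "kda"
        for i in range(n_layers)
    ]
```

With 12 layers, this places full attention at layer indices 3, 7 and 11.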

Byte-Level Models

Hourglass Transformer (modified) - link - val_loss:1.0048 train_time:2671049ms step_avg:524.76ms

AUNet - link - val_loss:1.1502 train_time:7246104ms step_avg:1423.60ms

SpaceByte - link - val_loss:1.6755 train_time:2154923ms step_avg:423.36ms peak memory consumption: 27781 MiB

HNet - link - val_loss:1.4554 train_time:2207809ms step_avg:433.75ms peak memory consumption: 23948 MiB

Optimizers

PaLMForeachSOAP - link - almost 2x slower than Adam but gives the best results

AdEMAMix - link - unstable even after trying different learning rates.

Adopt - link - loss went straight to NaN

CAdamW - link - loss: 3.3517

AdamW with independent weight decay - link - loss: 3.320

Adam - loss: 3.3224

AdamW - loss: 3.3310, peak VRAM: 42053 MiB, step_time: 533ms

DeMo - link - saves 7 GB per GPU, but the loss is higher than the baseline and the step time is slower than Adam - loss: 3.4676, peak VRAM: 41534 MiB, step_time: 820ms

Adam-Mini - link - loss is higher than with Adam and AdamW, and it is also (unexpectedly) slower; saves a bit of VRAM - loss: 3.3324, peak VRAM: 41534 MiB, step_time: 610ms

MARS - link - loss: 3.3459, peak VRAM: 40953 MiB, step_time: 628ms

Muon - link - loss: 3.2923, peak VRAM: 40332MB, step_time: 620.24ms

AdaMuon - paper | explanation - Adaptive Muon with second-moment estimation (default optimizer)

BiClip - link - (not working well) loss: 7.2292, peak VRAM: 39751 MiB, step_time: 510ms

NorMuon - paper | explanation - best_model_loss: 3.2411 -> loss: 3.4630, peak VRAM: 44154 MiB, step_time: 387.46 ms

Cautious Weight Decay - paper | explanation - best_model_loss: 3.2327 -> loss: 3.2334, peak VRAM: 45971 MiB, step_time: 434.2 ms
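CAdamW, listed above, applies the "cautious" rule from the Cautious Optimizers work. Update components whose sign disagrees with the current gradient are zeroed. The surviving components are then rescaled so that the average update size stays roughly the same. The sketch below shows only this masking step on plain Python lists. It assumes an Adam-style update direction `update` has already been computed and will be applied as `param -= lr * update`. The function name and the 1e-3 floor on the kept fraction are assumptions modeled on common implementations.

```python
from typing import List


def cautious_mask(
    update: List[float], grad: List[float], min_frac: float = 1e-3
) -> List[float]:
    """Apply the cautious mask to an optimizer update direction.

    Components where `update` and `grad` disagree in sign (or either is 0)
    are zeroed; the rest are divided by the fraction of kept components so
    the average update scale is roughly preserved.
    """
    mask = [1.0 if u * g > 0 else 0.0 for u, g in zip(update, grad)]
    # Fraction of components kept, floored to avoid dividing by ~0.
    frac = max(sum(mask) / len(mask), min_frac)
    return [u * m / frac for u, m in zip(update, mask)]
```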

Adding a New Model

  • Implement the model
  • Return the loss in the forward function
  • Add model to model_registry.py
  • Start training

See one of the implementations for details.
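The steps above follow a common registry pattern. The sketch below is hypothetical because the actual contents of model_registry.py are not shown here. `MODEL_REGISTRY`, `register_model`, `build_model` and the toy unigram model are illustrative only. The point is the contract from the list: the model's forward returns the loss, and the model is looked up by the name passed as `--model_name`.

```python
import math
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical registry mapping a --model_name string to a model class.
MODEL_REGISTRY: Dict[str, Callable[..., object]] = {}


def register_model(name: str):
    """Class decorator that adds a model to the registry under `name`."""
    def wrap(cls):
        if name in MODEL_REGISTRY:
            raise ValueError(f"model {name!r} is already registered")
        MODEL_REGISTRY[name] = cls
        return cls
    return wrap


@register_model("toy_unigram")
class ToyUnigram:
    """Toy model: a fixed uniform distribution over the vocabulary."""

    def __init__(self, vocab_size: int):
        self.probs = [1.0 / vocab_size] * vocab_size

    def forward(
        self, idx: List[int], targets: Optional[List[int]] = None
    ) -> Tuple[List[List[float]], Optional[float]]:
        # One probability row per input position.
        out = [self.probs for _ in idx]
        loss = None
        if targets is not None:
            # Return the mean cross-entropy alongside the outputs,
            # so the training loop only needs to call forward().
            loss = -sum(math.log(self.probs[t]) for t in targets) / len(targets)
        return out, loss


def build_model(name: str, **kwargs):
    """Look up a registered model by name and instantiate it."""
    try:
        cls = MODEL_REGISTRY[name]
    except KeyError:
        raise ValueError(f"unknown model {name!r}") from None
    return cls(**kwargs)
```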

Training

BlaGPT provides two training scripts:

| Script | Purpose | Data Sources |
|---|---|---|
| train.py | Quick experimentation with pre-tokenized data | Binary shards only |
| train_flex.py | Flexible training with custom datasets/tokenizers | Binary shards, HF streaming, byte-level |

Quick Start (train.py)

For rapid architecture benchmarking with the default FineWeb10B dataset:

# Get the pre-tokenized data
python data/fineweb10B_cached.py

# Start training
torchrun --standalone --nproc_per_node=8 bla_gpt/train.py --run_name my_experiment --model_name blagpt

Flexible Training (train_flex.py)

For custom datasets, different tokenizers, or HuggingFace streaming:

Using Pre-tokenized Binary Shards (Default)

torchrun --standalone --nproc_per_node=8 bla_gpt/train_flex.py \
    --model_name blagpt \
    --run_name my_experiment

Using HuggingFace Streaming Datasets

Stream data directly from the HuggingFace Hub; no pre-download is required:

# FineWeb with GPT-2 tokenizer (tiktoken)
torchrun --standalone --nproc_per_node=8 bla_gpt/train_flex.py \
    --model_name blagpt \
    --run_name hf_streaming_test \
    --use_hf_streaming \
    --hf_dataset "HuggingFaceFW/fineweb" \
    --hf_dataset_config "sample-10BT"

# C4 dataset with HuggingFace tokenizer
torchrun --standalone --nproc_per_node=8 bla_gpt/train_flex.py \
    --model_name blagpt \
    --run_name c4_experiment \
    --use_hf_streaming \
    --hf_dataset "allenai/c4" \
    --hf_dataset_config "en" \
    --tokenizer_backend huggingface \
    --tokenizer_name "gpt2"

# Custom dataset with custom text column
torchrun --standalone --nproc_per_node=8 bla_gpt/train_flex.py \
    --model_name blagpt \
    --run_name custom_data \
    --use_hf_streaming \
    --hf_dataset "username/my-dataset" \
    --hf_text_column "content"

Dataset Options

| Option | Default | Description |
|---|---|---|
| --use_hf_streaming | False | Enable HuggingFace streaming mode |
| --hf_dataset | HuggingFaceFW/fineweb | HuggingFace dataset name or path |
| --hf_dataset_config | sample-10BT | Dataset configuration/subset |
| --hf_split | train | Training split name |
| --hf_text_column | text | Column containing text data |
| --hf_val_dataset | Same as train | Separate validation dataset (optional) |
| --hf_val_dataset_config | Same as train | Validation dataset config |
| --hf_val_split | train | Validation split name |
| --hf_val_samples | 1000 | Number of samples for validation (first N of val split) |
| --hf_shuffle_buffer | 10000 | Shuffle buffer size for streaming |

Train/Validation Split: By default, the first 1000 samples are used for validation and training skips these samples (no overlap). When using a separate validation dataset or split, this behavior is automatically disabled.
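The split and shuffle behavior described above can be sketched over a plain Python iterator. This is a minimal sketch under three assumptions. The first `n_val` samples (the --hf_val_samples default of 1000) form the validation set. Training skips those samples. Training data then passes through a fixed-size shuffle buffer, whose size corresponds to --hf_shuffle_buffer. The function names are hypothetical. The Hugging Face `datasets` library offers comparable `take`, `skip` and `shuffle(buffer_size=...)` methods on streaming datasets, but this sketch does not claim that train_flex.py uses them.

```python
import itertools
import random
from typing import Callable, Iterable, Iterator, List, Tuple, TypeVar

T = TypeVar("T")


def shuffle_buffer(stream: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Approximately shuffle a stream using a fixed-size buffer."""
    if buffer_size <= 0:
        yield from stream  # no buffer: pass items through unchanged
        return
    rng = random.Random(seed)
    buf: List[T] = []
    for item in stream:
        if len(buf) < buffer_size:
            buf.append(item)  # fill the buffer first
            continue
        # Emit a random buffered item and put the new item in its slot.
        i = rng.randrange(buffer_size)
        yield buf[i]
        buf[i] = item
    # Stream exhausted: flush the remaining items in random order.
    rng.shuffle(buf)
    yield from buf


def split_stream(
    make_stream: Callable[[], Iterable[T]],
    n_val: int = 1000,
    buffer_size: int = 10000,
    seed: int = 0,
) -> Tuple[List[T], Iterator[T]]:
    """Use the first n_val samples for validation and the rest for training."""
    val = list(itertools.islice(make_stream(), n_val))
    # Training opens a fresh stream and skips the validation samples,
    # so the two sets never overlap.
    train = itertools.islice(make_stream(), n_val, None)
    return val, shuffle_buffer(train, buffer_size, seed)
```

A buffer-based shuffle only mixes samples within a window of about `buffer_size` items, so it is weaker than a full shuffle. Larger buffers mix more thoroughly at the cost of memory.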

Tokenizer Options

| Option | Default | Description |
|---|---|---|
| --tokenizer_backend | tiktoken | Backend: tiktoken or huggingface |
| --tokenizer_name | gpt2 | Tokenizer name (encoding or model path) |

tiktoken encodings: gpt2, cl100k_base, o200k_base

HuggingFace tokenizers: Any model name, e.g., meta-llama/Llama-2-7b-hf, mistralai/Mistral-7B-v0.1
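One way the two backends could be dispatched is sketched below. `load_tokenizer` is a hypothetical helper, not the repository's code. It relies only on the public `tiktoken.get_encoding` and `transformers.AutoTokenizer.from_pretrained` entry points. Both libraries are imported lazily, so only the selected backend needs to be installed.

```python
from typing import Callable, List, Tuple


def load_tokenizer(backend: str, name: str) -> Tuple[Callable[[str], List[int]], int]:
    """Return (encode_fn, vocab_size) for the requested backend."""
    if backend == "tiktoken":
        import tiktoken  # only needed for this backend

        enc = tiktoken.get_encoding(name)  # e.g. "gpt2", "cl100k_base"
        return enc.encode_ordinary, enc.n_vocab
    if backend == "huggingface":
        from transformers import AutoTokenizer  # only needed for this backend

        tok = AutoTokenizer.from_pretrained(name)  # e.g. "gpt2"

        def encode(text: str) -> List[int]:
            # Plain token ids, without BOS/EOS added by the tokenizer.
            return tok.encode(text, add_special_tokens=False)

        return encode, len(tok)
    raise ValueError(f"unknown tokenizer backend: {backend!r}")
```

For example, `load_tokenizer("tiktoken", "gpt2")` would return an encoder with GPT-2's 50,257-token vocabulary.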

Learning Rate Finder (Optional)

Run this before training to find a suitable learning rate:

torchrun --standalone --nproc_per_node=8 bla_gpt/find_lr.py --model_name blagpt

# Output
Results:
Steepest gradient learning rate: 3.31e-06
Elbow point learning rate: 1.20e-01
Plot saved to: logs/lr_finder_blagpt/lr_finder_plot.png
Results saved to: logs/lr_finder_blagpt/lr_finder_results.pt
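A learning rate finder typically runs an LR range test: the learning rate grows exponentially over a short run while the loss is recorded. The sketch below covers the exponential sweep and one plausible way to pick the "steepest gradient" value: the point where the smoothed loss falls fastest per decade of learning rate. This is an assumption about the method, not a copy of find_lr.py. The "elbow" value is not sketched because how find_lr.py locates it is not described here. All function names are hypothetical.

```python
import math
from typing import List


def lr_schedule(lr_min: float, lr_max: float, num_steps: int) -> List[float]:
    """Exponentially spaced learning rates from lr_min to lr_max (num_steps >= 2)."""
    ratio = lr_max / lr_min
    return [lr_min * ratio ** (i / (num_steps - 1)) for i in range(num_steps)]


def smooth(losses: List[float], beta: float = 0.98) -> List[float]:
    """Bias-corrected exponential moving average of a noisy loss curve."""
    avg, out = 0.0, []
    for step, loss in enumerate(losses, start=1):
        avg = beta * avg + (1 - beta) * loss
        out.append(avg / (1 - beta ** step))
    return out


def steepest_lr(lrs: List[float], losses: List[float]) -> float:
    """Return the LR at the start of the steepest drop in loss per log10(lr)."""
    log_lrs = [math.log10(lr) for lr in lrs]
    slopes = [
        (losses[i + 1] - losses[i]) / (log_lrs[i + 1] - log_lrs[i])
        for i in range(len(lrs) - 1)
    ]
    # The most negative slope is where the loss falls fastest.
    best = min(range(len(slopes)), key=slopes.__getitem__)
    return lrs[best]
```

In practice `steepest_lr` would be called on `smooth(raw_losses)`, because raw per-step losses are too noisy for finite differences.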

Best Model So Far

  • Check best_model_config.py for the best model configuration so far.

  • To train with the best model config, run:

torchrun --standalone --nproc_per_node=8 train.py --run_name best_model --model_name best

Acknowledgements

The initial code is based on

Nano GPT - link

Modded NanoGPT - link

Thanks to @xumingyu2021 for the memory-friendly implementation of Differential Attention.
