BlaGPT

Experimental playground for benchmarking language model (LM) architectures, layers, and tricks on smaller datasets. Designed for flexible experimentation and exploration.

📚 Technique Documentation

See the techniques/ directory for technical explanations and code snippets for various techniques and models implemented in the project.

Techniques under BlaGPT

BlaGPT is a flexible Transformer implementation in which each of the following techniques can be turned on or off in the config.

I basically do a greedy architecture search: I add each new technique on top of the best model config so far and check whether it improves performance. I know this is not the best way to search, but it is fun and helps build at least some intuition about what works and what doesn't.
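The greedy search described above can be sketched in a few lines. This is a minimal, framework-free illustration, not the BlaGPT code: `train_and_eval` is a hypothetical stand-in for a full training run that returns the final validation loss, and the flag names and toy loss effects are made up for the example (they are not measured results).

```python
from typing import Callable, Dict, List, Tuple

Config = Dict[str, bool]


def greedy_search(
    base: Config,
    candidates: List[str],
    train_and_eval: Callable[[Config], float],
) -> Tuple[Config, float]:
    """Try each technique on top of the best config found so far.

    A technique is kept only if it lowers the validation loss; otherwise the
    best config is left unchanged and the next candidate is tried.
    """
    best_cfg = dict(base)
    best_loss = train_and_eval(best_cfg)
    for name in candidates:
        trial = {**best_cfg, name: True}  # switch one technique on
        loss = train_and_eval(trial)
        if loss < best_loss:  # keep it only if it helps
            best_cfg, best_loss = trial, loss
    return best_cfg, best_loss


# Hypothetical stand-in for a training run; the effects are toy values.
def fake_train_and_eval(cfg: Config) -> float:
    effects = {"use_gated_attention": -0.01, "use_z_loss": +0.02}
    return 3.25 + sum(effects.get(k, 0.0) for k, on in cfg.items() if on)


cfg, loss = greedy_search({}, ["use_gated_attention", "use_z_loss"], fake_train_and_eval)
print(cfg, round(loss, 4))
```

The search is order-dependent: a technique that only helps in combination with a later one can be rejected early. That is the price of one training run per candidate.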

Multi-token prediction - link

Weight tying - link

Grouped query attention - link

Capping logits - link

QKV bias - link

Zero-init projection layer - link

Post and pre-RMSNorm - link

Setting base theta to 1_000_000 - llama3 - increased the final validation loss - best 3.3324

Z-loss regularization - link - increased the final validation loss by 0.02 - loss: 3.3527

KV-Shifting attention - link - seems to improve performance - loss: 3.3310 -> 3.3138 - peak memory consumption: 42858 MiB

Dilated Attention (LongNet) - link

Multi-Head Latent Attention - link - loss: 3.3479 - peak memory consumption: 42192 MiB

Per token output bias - link - loss: 3.3257 - peak memory consumption: 42120 MiB

DyT Norm - link - didn't really work; the loss got stuck too high

Forgetting Transformer (Vanilla and Pro versions) - link - vanilla loss: 3.3243, pro loss: OOM

Multi-Token Attention - link - loss: 3.3357 - peak memory: 42136 MiB

Differential Attention - link - best_model_loss: 3.2411 -> loss: 3.2460 - peak memory: 41521 MiB

Softpick - link - loss: 3.3446 - peak memory: 59417 MiB

Canon Layer - link - loss: 3.3217 - peak memory: 43199 MiB

Parallel Transformer Block - link - loss: 3.3473 - peak memory: 40302 MiB

Per Layer Token Embedding - link - loss: 3.2411 - peak memory: 40916 MiB

PolyNorm - link - best_model_loss: 3.2411 -> loss: 3.3017 - peak memory: 40895 MiB

PolyReLU - link - best_model_loss: 3.2411 -> loss: 3.2642 - peak memory: 40890 MiB

TOP loss - paper | explanation - best_model_loss: 3.2411 -> loss: 3.2636 - peak memory: 47816 MiB

Simplified RoPE - link - best_model_loss: 3.2411 -> loss: 3.2620 - peak memory: 43585 MiB - step_avg: 388.54ms

Gated Attention - paper | explanation - best_model_loss: 3.2411 -> new_best_model_loss: 3.2327 - peak memory: 45968 MiB - step_avg: 413.01ms

ResFormer (Plus) - paper | explanation - best_model_loss: 3.2327 -> model_loss: 3.3538 - peak memory: 38223 MiB - step_avg: 326.32ms

Engram (My Simple Variant) - paper | explanation - best_model_loss: 3.2327 -> new_best_model_loss: 3.2296 - peak memory: 50488 MiB - step_avg: 504.09ms

👑 Differential Attention v2 - paper | explanation - best_model_loss: 3.2296 -> new_best_model_loss: 3.2274 - peak memory: 52829 MiB - step_avg: 535.16ms
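Two items in the list above, Capping logits and Z-loss regularization, act directly on the output logits. The sketch below is plain Python for a single token's logits and is not the BlaGPT implementation. It assumes tanh soft-capping of the form `cap * tanh(x / cap)`, with an illustrative cap of 30.0. It also assumes the z-loss form `z_coef * log(Z)**2` with coefficient 1e-4, as used in PaLM. BlaGPT's actual values may differ.

```python
import math
from typing import List


def soft_cap(logits: List[float], cap: float = 30.0) -> List[float]:
    """Smoothly squash logits into (-cap, cap) with cap * tanh(x / cap)."""
    return [cap * math.tanh(x / cap) for x in logits]


def log_sum_exp(logits: List[float]) -> float:
    """Numerically stable log(sum(exp(x)))."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits))


def loss_with_z_loss(logits: List[float], target: int, z_coef: float = 1e-4) -> float:
    """Cross-entropy for one token plus the z-loss term z_coef * log(Z)^2.

    log(Z) is the log-partition function (log-sum-exp of the logits). The
    penalty pulls it toward 0, discouraging the logits from drifting upward.
    """
    lse = log_sum_exp(logits)
    ce = lse - logits[target]  # equals -log(softmax(logits)[target])
    return ce + z_coef * lse ** 2


capped = soft_cap([100.0, 0.0, -100.0])
print([round(x, 3) for x in capped])
print(loss_with_z_loss([2.0, 0.0, -1.0], target=0))
```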

Other Models

MegaByte - link - loss: 3.810

FTP (heavily modified) - link - loss: 3.901

Rene - link - loss: 3.340

Rwkv7 - link - loss: 4.450

Zamba2 - link - Zamba2 > Rene > Rwkv7

Hourglass Transformer (modified) - link - Hourglass > MegaByte > FTP - loss: 3.710

Hymba - link - train step time is significantly slower than the transformers. Best validation loss so far: 4.7505

Tokenformer (in BlaGPT model) - link - loss: 3.390

LLaDa (dLLM) - link - val-loss: 8.6930, xentropy-loss: 4.2891 (the xentropy-loss is the figure comparable to the other models; estimated by llada_validation_cross_entropy.py)

Avey - link - loss: 3.323, peak memory: 51962 MiB (batch size 8), step_time: 2871ms (very slow to train and uses >3x more memory than other models)

LFM2 - link - TBD

Kimi Delta Attention (1:3 interleaved Full Attention) - link - best_model_loss: 3.2411 -> loss: 3.2532, peak memory: 47391 MiB, step_time: 568.1ms
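The "1:3 interleaved" setup above mixes full attention with Kimi Delta Attention layers. Reading "1:3" as one full-attention layer for every three KDA layers, a layer schedule could look like the sketch below. The `layer_types` helper is hypothetical, and putting the full-attention layer last in each group of four is an assumption. Neither is taken from the BlaGPT config.

```python
from typing import List


def layer_types(n_layers: int, full_every: int = 4) -> List[str]:
    """Return the attention type for each layer index.

    With full_every=4, every 4th layer uses full (softmax) attention and the
    other three use Kimi Delta Attention, i.e. a 1:3 full-to-KDA ratio.
    """
    return ["full" if (i + 1) % full_every == 0 else "kda" for i in range(n_layers)]


print(layer_types(8))
# ['kda', 'kda', 'kda', 'full', 'kda', 'kda', 'kda', 'full']
```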

Byte-Level Models

Hourglass Transformer (modified) - link - val_loss:1.0048 train_time:2671049ms step_avg:524.76ms

AUNet - link - val_loss:1.1502 train_time:7246104ms step_avg:1423.60ms

SpaceByte - link - val_loss:1.6755 train_time:2154923ms step_avg:423.36ms peak memory consumption: 27781 MiB

HNet - link - val_loss:1.4554 train_time:2207809ms step_avg:433.75ms peak memory consumption: 23948 MiB
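Byte-level models skip the subword tokenizer and read raw bytes, so the input vocabulary is typically just the 256 possible byte values. One consequence is that the val_loss numbers in this section are presumably averaged per byte rather than per subword token. If so, they are not directly comparable with the token-level losses above. A minimal sketch of the encoding, assuming plain UTF-8 bytes. The actual BlaGPT byte-level pipeline may add special tokens or ID offsets.

```python
from typing import List


def encode_bytes(text: str) -> List[int]:
    """Map text to byte IDs in [0, 255] via UTF-8."""
    return list(text.encode("utf-8"))


def decode_bytes(ids: List[int]) -> str:
    """Inverse of encode_bytes; invalid byte sequences are replaced, not raised."""
    return bytes(ids).decode("utf-8", errors="replace")


ids = encode_bytes("héllo")
print(ids)  # 'é' takes two bytes in UTF-8, so 5 characters become 6 IDs
print(decode_bytes(ids))
```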

Optimizers

PaLMForeachSOAP - link - almost 2x slower than Adam, but gives the best results

AdEMAMix - link - unstable even after trying different learning rates

Adopt - link - the loss went straight to NaN

CAdamW - link - loss: 3.3517

AdamW with independent weight decay - link - loss: 3.320

Adam - loss: 3.3224

AdamW - loss: 3.3310, peak VRAM: 42053 MiB, step_time: 533ms

DeMo - link - Saves 7 GB per GPU, loss is higher than baseline, step time is slower than Adam - loss: 3.4676, peak VRAM: 41534 MiB, step_time: 820ms

Adam-Mini - link - loss is higher than Adam and AdamW, step time is unexpectedly slower, and it saves a bit of VRAM - loss: 3.3324, peak VRAM: 41534 MiB, step_time: 610ms

MARS - link - loss: 3.3459, peak VRAM: 40953 MiB, step_time: 628ms

Muon - link - loss: 3.2923, peak VRAM: 40332MB, step_time: 620.24ms

AdaMuon - paper | explanation - Adaptive Muon with second-moment estimation (default optimizer)

BiClip - link - (not working well) loss: 7.2292, peak VRAM: 39751 MiB, step_time: 510ms

NorMuon - paper | explanation - best_model_loss: 3.2411 -> loss: 3.4630, peak VRAM: 44154 MiB, step_time: 387.46 ms

Cautious Weight Decay - paper | explanation - best_model_loss: 3.2327 -> loss: 3.2334, peak VRAM: 45971 MiB, step_time: 434.2 ms
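For reference, the difference between the Adam and AdamW entries above is where weight decay enters the update. Below is a minimal single-parameter sketch of one AdamW step with decoupled weight decay. It follows the form used by PyTorch's `torch.optim.AdamW`: shrink the parameter by `lr * wd`, then apply the Adam update. The hyperparameter defaults here are illustrative, not BlaGPT's settings.

```python
import math
from dataclasses import dataclass
from typing import Tuple


@dataclass
class AdamWState:
    m: float = 0.0  # first-moment (mean) estimate of the gradient
    v: float = 0.0  # second-moment (uncentered variance) estimate
    t: int = 0      # step counter used for bias correction


def adamw_step(p: float, g: float, s: AdamWState, lr: float = 1e-3,
               betas: Tuple[float, float] = (0.9, 0.999),
               eps: float = 1e-8, wd: float = 0.0) -> float:
    """Return the updated parameter after one AdamW step.

    Weight decay is decoupled: the parameter is shrunk by lr * wd directly,
    instead of adding wd * p to the gradient as Adam with L2 would.
    """
    b1, b2 = betas
    s.t += 1
    s.m = b1 * s.m + (1 - b1) * g
    s.v = b2 * s.v + (1 - b2) * g * g
    m_hat = s.m / (1 - b1 ** s.t)  # bias-corrected moments
    v_hat = s.v / (1 - b2 ** s.t)
    p = p * (1 - lr * wd)  # decoupled weight decay
    return p - lr * m_hat / (math.sqrt(v_hat) + eps)


state = AdamWState()
p = adamw_step(1.0, 0.5, state, lr=0.1)
print(p)  # the first step moves the parameter by roughly lr * sign(grad)
```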

Adding a New Model

  • Implement the model
  • Return the loss in the forward function
  • Add model to model_registry.py
  • Start training

See one of the implementations for details.
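The steps above can be illustrated with a toy example. Everything in it is hypothetical: `MODEL_REGISTRY`, `register_model` and `ToyBigramModel` are illustrative names, not the actual contents of model_registry.py. A real BlaGPT model would also be a PyTorch module rather than plain Python. The sketch only shows the contract: `forward` returns the loss when targets are given, and a registry maps a `--model_name` string to a constructor.

```python
import math
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical registry mapping --model_name strings to model constructors.
MODEL_REGISTRY: Dict[str, Callable[..., object]] = {}


def register_model(name: str):
    """Decorator that adds a model class to MODEL_REGISTRY under `name`."""
    def wrap(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return wrap


@register_model("toy_bigram")
class ToyBigramModel:
    """Bigram LM: the next-token logits depend only on the current token."""

    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size
        # logits[i][j] = score of token j following token i (all zeros = uniform)
        self.logits = [[0.0] * vocab_size for _ in range(vocab_size)]

    def forward(self, idx: List[int], targets: Optional[List[int]] = None
                ) -> Tuple[List[List[float]], Optional[float]]:
        out = [self.logits[i] for i in idx]
        if targets is None:
            return out, None
        # Mean cross-entropy over positions, returned by forward as the loss.
        total = 0.0
        for row, y in zip(out, targets):
            m = max(row)
            lse = m + math.log(sum(math.exp(x - m) for x in row))
            total += lse - row[y]
        return out, total / len(targets)


model = MODEL_REGISTRY["toy_bigram"](vocab_size=4)
_, loss = model.forward([0, 1, 2], targets=[1, 2, 3])
print(loss)  # uniform predictions give log(vocab_size), about 1.386
```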

Training

BlaGPT provides two training scripts:

Script          Purpose                                              Data Sources
train.py        Quick experimentation with pre-tokenized data        Binary shards only
train_flex.py   Flexible training with custom datasets/tokenizers    Binary shards, HF streaming, byte-level

Quick Start (train.py)

For rapid architecture benchmarking with the default FineWeb10B dataset:

# Get the pre-tokenized data
python data/fineweb10B_cached.py

# Start training
torchrun --standalone --nproc_per_node=8 bla_gpt/train.py --run_name my_experiment --model_name blagpt

Flexible Training (train_flex.py)

For custom datasets, different tokenizers, or HuggingFace streaming:

Using Pre-tokenized Binary Shards (Default)

torchrun --standalone --nproc_per_node=8 bla_gpt/train_flex.py \
    --model_name blagpt \
    --run_name my_experiment

Using HuggingFace Streaming Datasets

Stream data directly from HuggingFace Hub - no pre-download required:

# FineWeb with GPT-2 tokenizer (tiktoken)
torchrun --standalone --nproc_per_node=8 bla_gpt/train_flex.py \
    --model_name blagpt \
    --run_name hf_streaming_test \
    --use_hf_streaming \
    --hf_dataset "HuggingFaceFW/fineweb" \
    --hf_dataset_config "sample-10BT"

# C4 dataset with HuggingFace tokenizer
torchrun --standalone --nproc_per_node=8 bla_gpt/train_flex.py \
    --model_name blagpt \
    --run_name c4_experiment \
    --use_hf_streaming \
    --hf_dataset "allenai/c4" \
    --hf_dataset_config "en" \
    --tokenizer_backend huggingface \
    --tokenizer_name "gpt2"

# Custom dataset with custom text column
torchrun --standalone --nproc_per_node=8 bla_gpt/train_flex.py \
    --model_name blagpt \
    --run_name custom_data \
    --use_hf_streaming \
    --hf_dataset "username/my-dataset" \
    --hf_text_column "content"

Dataset Options

Option                     Default                 Description
--use_hf_streaming         False                   Enable HuggingFace streaming mode
--hf_dataset               HuggingFaceFW/fineweb   HuggingFace dataset name or path
--hf_dataset_config        sample-10BT             Dataset configuration/subset
--hf_split                 train                   Training split name
--hf_text_column           text                    Column containing text data
--hf_val_dataset           Same as train           Separate validation dataset (optional)
--hf_val_dataset_config    Same as train           Validation dataset config
--hf_val_split             train                   Validation split name
--hf_val_samples           1000                    Number of samples for validation (first N of val split)
--hf_shuffle_buffer        10000                   Shuffle buffer size for streaming

Train/Validation Split: By default, the first 1000 samples are used for validation and training skips these samples (no overlap). When using a separate validation dataset or split, this behavior is automatically disabled.
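The default split behaviour described above can be sketched over any iterable of samples. This is a plain-Python illustration, not the train_flex.py code. `split_stream` and `buffered_shuffle` are hypothetical helpers. Applying the shuffle only to the training stream, after the validation samples are skipped, is an assumption. The buffer shuffle shows the usual streaming approach behind an option like `--hf_shuffle_buffer`: keep a fixed-size buffer and emit a random element from it each time a new sample arrives.

```python
import itertools
import random
from typing import Callable, Iterable, Iterator, List, Tuple, TypeVar

T = TypeVar("T")


def split_stream(make_stream: Callable[[], Iterable[T]],
                 n_val: int = 1000) -> Tuple[List[T], Iterator[T]]:
    """Take the first n_val samples for validation; training skips them.

    make_stream is called twice so both readers start from the beginning
    independently, as two separate streaming iterators would.
    """
    val = list(itertools.islice(make_stream(), n_val))
    train = itertools.islice(make_stream(), n_val, None)  # no overlap with val
    return val, train


def buffered_shuffle(stream: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Approximate shuffling for streams too large to load into memory."""
    rng = random.Random(seed)
    buf: List[T] = []
    for item in stream:
        if len(buf) < buffer_size:
            buf.append(item)  # fill the buffer first
            continue
        i = rng.randrange(buffer_size)
        yield buf[i]   # emit a random buffered sample
        buf[i] = item  # and replace it with the incoming one
    rng.shuffle(buf)   # drain whatever is left at the end of the stream
    yield from buf


val, train = split_stream(lambda: iter(range(10)), n_val=3)
train_shuffled = list(buffered_shuffle(train, buffer_size=4))
print(val, sorted(train_shuffled))
```

A larger buffer gives a better approximation of a full shuffle at the cost of memory. With a buffer of 1 the stream order is unchanged.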

Tokenizer Options

Option                 Default     Description
--tokenizer_backend    tiktoken    Backend: tiktoken or huggingface
--tokenizer_name       gpt2        Tokenizer name (encoding or model path)

tiktoken encodings: gpt2, cl100k_base, o200k_base

HuggingFace tokenizers: Any model name, e.g., meta-llama/Llama-2-7b-hf, mistralai/Mistral-7B-v0.1

Learning Rate Finder (Optional)

Run this before training to find a good learning rate:

torchrun --standalone --nproc_per_node=8 bla_gpt/find_lr.py --model_name blagpt

# Output
Results:
Steepest gradient learning rate: 3.31e-06
Elbow point learning rate: 1.20e-01
Plot saved to: logs/lr_finder_blagpt/lr_finder_plot.png
Results saved to: logs/lr_finder_blagpt/lr_finder_results.pt
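The two numbers in the output come from a learning-rate range test. The idea is to train briefly while increasing the learning rate exponentially, record the loss at each step, and then pick points on the loss-versus-learning-rate curve. The sketch below shows one common way to compute both from recorded (lr, loss) pairs. It is an assumption about the method, not the find_lr.py code. "Steepest gradient" is taken as the learning rate where the loss falls fastest with respect to log(lr). The elbow uses a simple farthest-below-the-chord rule, in the spirit of Kneedle. find_lr.py may smooth the curve or define the elbow differently.

```python
import math
from typing import List


def steepest_gradient_lr(lrs: List[float], losses: List[float]) -> float:
    """LR at the start of the segment where d(loss)/d(log lr) is most negative."""
    slopes = [
        (losses[i + 1] - losses[i]) / (math.log(lrs[i + 1]) - math.log(lrs[i]))
        for i in range(len(lrs) - 1)
    ]
    i = min(range(len(slopes)), key=slopes.__getitem__)
    return lrs[i]


def elbow_lr(lrs: List[float], losses: List[float]) -> float:
    """LR whose point lies farthest below the chord joining the curve's ends.

    Works in (log lr, loss) space; a simple knee-detection heuristic.
    """
    xs = [math.log(lr) for lr in lrs]
    x0, y0, x1, y1 = xs[0], losses[0], xs[-1], losses[-1]

    def below_chord(x: float, y: float) -> float:
        chord_y = y0 + (y1 - y0) * (x - x0) / (x1 - x0)
        return chord_y - y  # positive when the curve dips below the chord

    i = max(range(len(xs)), key=lambda k: below_chord(xs[k], losses[k]))
    return lrs[i]


# Synthetic range-test curve: flat, falling, bottoming out, then diverging.
lrs = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0]
losses = [5.0, 4.9, 4.0, 3.0, 2.9, 6.0]
print(steepest_gradient_lr(lrs, losses), elbow_lr(lrs, losses))
```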

Best Model So Far

  • Check best_model_config.py for the best model configuration so far.

  • You can run the training with the best model config by running:

torchrun --standalone --nproc_per_node=8 train.py --run_name best_model --model_name best

Acknowledgements

The initial code is based on:

Nano GPT - link

Modded NanoGPT - link

Thanks to @xumingyu2021 for the memory-friendly implementation of Differential Attention