nano-transformer is a from-scratch LLM codebase that starts with core Transformer primitives and grows into a compact but genuinely end-to-end training stack.
It is meant for people who want more than a toy GPT file, but less ceremony than a large production framework. The repo includes tokenizer training, dense and MoE model variants, modern attention features, staged pretraining, checkpointing, evaluation, and sampling in a layout that is explicit enough to study and extend comfortably.
- Build the model stack directly in PyTorch instead of hiding it behind a large framework.
- Keep the code organized enough for real experimentation: dense baselines, GQA, RoPE scaling, KV cache, MoE, schedules, and optimizer variants all live in a coherent package.
- Preserve readability. The code is designed to be inspectable, testable, and modifiable without turning into a single monolithic training script.
- Cover the full workflow:
  - train or load a BPE tokenizer
  - prepare token caches
  - pretrain a model
  - resume from checkpoints
  - evaluate
  - sample text
| Area | Implemented |
|---|---|
| Tokenization | GPT-2-style byte-level BPE, tokenizer loading, special-token handling, tokenizer training utilities |
| Core model | Embedding, RMSNorm, SwiGLU, merged QKV, RoPE, QK-Norm, grouped-query attention, optional tied embeddings |
| Attention variants | SDPA fast path, explicit fallback path, sliding-window attention, global prefix tokens, RoPE scaling (linear, yarn) |
| Inference | Preallocated KV cache, prefill, decode_step, cache branching, greedy / sample / top-k / top-p decoding |
| MoE | Vectorized token-choice routing, shared experts, capacity control, dropped-token accounting, aux loss, z-loss, bias balancing |
| Training | Document-aware batching, EoS-prefixed document tokenization, max-document-length truncation, gradient accumulation, staged schedules |
| Optimizers | AdamW, Muon, hybrid Muon+AdamW grouping, cautious weight decay, LR-tied weight decay scheduling |
| Runtime | TF32 support, autocast, optional torch.compile, DDP support, checkpoint save/load, resume support |
| Validation | Extensive unit tests, snapshot tests, script-level coverage for pretrain / eval / sample flows |
Current local validation for this repository:

- `uv run pytest` → 117 passed, 2 skipped
- script flows are covered in `tests/test_scripts.py`
- the skipped tests are memory-limit checks that depend on local resource limits
```mermaid
flowchart LR
    A["Raw text corpora"] --> B["BPE tokenizer<br/>vocab + merges"]
    B --> C["Tokenized dataset caches<br/>+ document ids"]
    C --> D["TransformerLM / MoE model"]
    D --> E["Trainer<br/>staged schedules + optimizer"]
    E --> F["Checkpoints<br/>latest.pt / best.pt"]
    F --> G["eval.py"]
    F --> H["sample.py"]
```
```
src/nano_transformer/
  __init__.py
  model.py
  moe.py
  inference/
    decoding.py
    kv_cache.py
  tokenization/
    tokenizer.py
    bpe.py
    gpt2.py
  training/
    checkpointing.py
    data.py
    optim.py
    trainer.py
scripts/
  pretrain.py
  sample.py
  eval.py
  train_bpe.py
tests/
  test_model.py
  test_inference.py
  test_moe.py
  test_optimizer.py
  test_trainer.py
  test_scripts.py
  ...
```
This repository uses uv and is packaged with a src/ layout.

- Python >= 3.11
- uv
- PyTorch is resolved from `pyproject.toml` based on platform

```
brew install uv
uv sync
```

If you prefer not to sync eagerly, `uv run ...` will also resolve and run commands in the project environment.
For editable package-style imports outside `uv run`, you can also do:

```
uv pip install -e .
```

Verify the environment:

```
uv run pytest
```

The distinction between what is committed and what is generated locally matters here. The repository commits:

- the package code under `src/nano_transformer/`
- runnable scripts under `scripts/`
- the test suite and fixtures under `tests/`
- tokenizer artifacts for the named datasets
The large raw text corpora are intentionally not committed. By default, the named dataset shortcuts expect local files such as:

- `data/TinyStoriesV2-GPT4-train.txt`
- `data/TinyStoriesV2-GPT4-valid.txt`
- `data/owt_train.txt`
- `data/owt_valid.txt`

Generated artifacts are also local:

- token caches (`*.npy`)
- document-id caches
- checkpoints
- run directories under `out/`
If you do not want to use the named dataset shortcuts, you can point the training script directly at your own files with `--train-text-path`, `--valid-text-path`, `--vocab-path`, and `--merges-path`.
`scripts/pretrain.py` accepts the named dataset shortcuts `tinystories`, `openwebtext`, and `owt`.
Example:

```
mkdir -p data
cd data
curl -L -o TinyStoriesV2-GPT4-train.txt \
  https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-train.txt
curl -L -o TinyStoriesV2-GPT4-valid.txt \
  https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-valid.txt
curl -L -o owt_train.txt.gz \
  https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_train.txt.gz
gunzip -f owt_train.txt.gz
curl -L -o owt_valid.txt.gz \
  https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_valid.txt.gz
gunzip -f owt_valid.txt.gz
```

The training data path is more than "flatten token stream and slice windows":
- documents are tokenized with a document-EoS prefix when possible
- token caches are fingerprinted so tokenizer changes invalidate stale caches
- optional document-id arrays are stored alongside token caches
- batch starts can prefer document boundaries
- cross-document targets can be masked out
- documents can be truncated to `--max-document-length`
These behaviors live primarily in `src/nano_transformer/training/data.py`.
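The cache-fingerprinting behavior above can be implemented by hashing the tokenizer artifacts and recording the digest next to the token cache. A minimal sketch of the idea (the helper names here are illustrative, not the repo's actual API):

```python
import hashlib
from pathlib import Path


def tokenizer_fingerprint(*artifact_paths: Path) -> str:
    """Hash tokenizer artifacts (vocab, merges) so a token cache can
    record which tokenizer produced it; any change yields a new digest."""
    h = hashlib.sha256()
    for path in sorted(artifact_paths):
        h.update(path.name.encode())
        h.update(path.read_bytes())
    return h.hexdigest()


def cache_is_stale(cache_meta_path: Path, fingerprint: str) -> bool:
    """A cache is stale if its recorded fingerprint is missing or differs."""
    if not cache_meta_path.exists():
        return True
    return cache_meta_path.read_text().strip() != fingerprint
```

Rebuilding only when `cache_is_stale(...)` returns true gives the "tokenizer changes invalidate stale caches" property without any manual bookkeeping.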
Run the test suite:

```
uv run pytest
```

Train a small model:

```
uv run scripts/pretrain.py \
  --dataset tinystories \
  --context-length 128 \
  --d-model 256 \
  --num-layers 4 \
  --num-heads 8 \
  --batch-size 16 \
  --gradient-accumulation-steps 4 \
  --max-iters 1000 \
  --device cpu
```

Sample from it:

```
uv run scripts/sample.py \
  --run-dir out/pretrain/tinystories \
  --prompt "Once upon a time" \
  --max-new-tokens 64 \
  --strategy top_k \
  --top-k 40
```

Evaluate it:

```
uv run scripts/eval.py \
  --run-dir out/pretrain/tinystories \
  --eval-iters 20
```

Training writes a run directory such as:
```
out/pretrain/tinystories/
  latest.pt
  best.pt
  run_config.json
  trainer_config.json
  metrics.jsonl
```
What these files are for:

- `latest.pt`: latest checkpoint
- `best.pt`: best validation checkpoint tracked so far
- `run_config.json`: dataset paths, tokenizer paths, special tokens, and model config used for the run
- `trainer_config.json`: trainer/runtime configuration for the run
- `metrics.jsonl`: per-step training and evaluation metrics such as loss, CE loss, tokens/sec, grad norm, active schedule state, and MoE counters
There are two practical ways to train or refresh tokenizer artifacts.
This is the most flexible path:
```python
from nano_transformer import train_bpe

vocab, merges = train_bpe(
    input_path="data/TinyStoriesV2-GPT4-train.txt",
    vocab_size=10000,
    special_tokens=["<|endoftext|>"],
)
```

`scripts/train_bpe.py` contains small helper functions for serializing tokenizer artifacts for the bundled datasets. It is intentionally lightweight, and it is not a full CLI wrapper yet.

```
uv run scripts/train_bpe.py
```

If you want custom tokenizer-training behavior, prefer the library API or edit the helper functions in `scripts/train_bpe.py`.
`scripts/pretrain.py` is the main training entrypoint.
The default configuration is intentionally modest:
| Setting | Default |
|---|---|
| `context_length` | 128 |
| `d_model` | 256 |
| `num_layers` | 4 |
| `num_heads` | 8 |
| `batch_size` | 16 |
| `optimizer` | `adamw` |
| `tie_embeddings` | `True` |
| `scale_embeddings` | `True` |
| `use_qk_norm` | `True` |
| `zero_init_residual` | `True` |
```
uv run scripts/pretrain.py \
  --dataset tinystories \
  --context-length 256 \
  --d-model 384 \
  --num-layers 6 \
  --num-heads 8 \
  --batch-size 8 \
  --gradient-accumulation-steps 8 \
  --max-iters 5000 \
  --learning-rate 3e-4 \
  --device cuda
```

You can stage context length, batch size, attention window size, and MTP weights over training.
```
uv run scripts/pretrain.py \
  --dataset tinystories \
  --context-length 256 \
  --context-stage 1000:512 \
  --context-stage 3000:1024 \
  --training-stage start=1000,window=256 \
  --training-stage start=3000,window=full,mtp=1.0|0.25 \
  --batch-size 16 \
  --batch-size-stage 3000:8 \
  --max-iters 5000 \
  --rope-scaling-type yarn \
  --rope-scaling-factor 4.0
```

Notes:

- `window=<int>` enables sliding-window attention.
- `window=full` explicitly restores full attention.
- `--context-stage` and `--batch-size-stage` are shorthand schedules.
- `--training-stage` is the general mechanism and can change:
  - batch size
  - context length
  - attention window
  - MTP weights
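The `--training-stage` spec format shown above (`start=3000,window=full,mtp=1.0|0.25`) is simple enough to parse by hand. An illustrative sketch of the format, not the script's actual parser:

```python
def parse_training_stage(spec: str) -> dict:
    """Parse a stage spec like "start=3000,window=full,mtp=1.0|0.25".

    `start` is an iteration count, `window` is a sliding-window size
    (with "full" meaning no window), and `mtp` is a |-separated list
    of multi-token-prediction weights."""
    out: dict = {}
    for pair in spec.split(","):
        key, value = pair.split("=", 1)
        if key == "start":
            out[key] = int(value)
        elif key == "window":
            out[key] = None if value == "full" else int(value)
        elif key == "mtp":
            out[key] = [float(w) for w in value.split("|")]
        else:
            out[key] = value
    return out
```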
MoE training:

```
uv run scripts/pretrain.py \
  --dataset tinystories \
  --num-experts 8 \
  --num-experts-per-token 2 \
  --num-shared-experts 1 \
  --moe-layers all \
  --moe-capacity-factor 1.25 \
  --moe-aux-loss-coef 0.01 \
  --moe-z-loss-coef 0.001 \
  --moe-router-balance-strategy aux_loss
```

Muon optimizer:

```
uv run scripts/pretrain.py \
  --dataset tinystories \
  --optimizer muon \
  --cautious-weight-decay \
  --muon-momentum 0.95 \
  --muon-ns-steps 5
```

Distributed training with DDP:

```
python3 -m torch.distributed.run \
  --nproc_per_node 2 \
  scripts/pretrain.py \
  --dataset tinystories \
  --device cuda
```

On non-CUDA distributed environments, the script falls back to CPU instead of trying unsupported device backends.
- Model shape: `--context-length`, `--d-model`, `--d-ff`, `--num-layers`, `--num-heads`, `--num-kv-heads`
- Attention and RoPE: `--rope-theta`, `--rope-scaling-type`, `--rope-scaling-factor`, `--rope-base-context-length`, `--attention-window-size`, `--attention-global-tokens`, `--no-qk-norm`, `--attn-logit-softcap`, `--final-logit-softcap`
- MoE: `--num-experts`, `--num-experts-per-token`, `--num-shared-experts`, `--moe-layers`, `--moe-capacity-factor`, `--moe-router-balance-strategy`
- Data: `--max-document-length`, `--document-boundary-sampling-probability`, `--force-rebuild-cache`
- Optimization: `--optimizer`, `--learning-rate`, `--weight-decay`, `--gradient-accumulation-steps`, `--cautious-weight-decay`, `--no-scale-lr-with-batch-size`, `--no-schedule-weight-decay-with-lr`
- Runtime: `--device`, `--dtype`, `--compile`, `--pin-memory`, `--no-tf32`, `--resume-from`
`scripts/sample.py` supports:
- greedy decoding
- multinomial sampling
- top-k sampling
- top-p / nucleus sampling
- cache-backed generation
- multi-sample generation
- optional no-cache path for debugging and parity checks
Greedy decoding:

```
uv run scripts/sample.py \
  --run-dir out/pretrain/tinystories \
  --prompt "The dragon" \
  --strategy greedy
```

Top-k sampling with temperature:

```
uv run scripts/sample.py \
  --run-dir out/pretrain/tinystories \
  --prompt "The dragon" \
  --strategy top_k \
  --top-k 50 \
  --temperature 0.8 \
  --max-new-tokens 128
```

Top-p multi-sample generation without the KV cache:

```
uv run scripts/sample.py \
  --run-dir out/pretrain/tinystories \
  --prompt "The dragon" \
  --strategy top_p \
  --top-p 0.9 \
  --num-samples 4 \
  --no-kv-cache
```

Notes:

- empty prompts are supported if the tokenizer exposes a special token that can be used as a BOS seed
- `--num-samples` works for both cache-backed and no-cache generation paths
- the script rebuilds the saved base model configuration from `run_config.json`
`scripts/eval.py` evaluates a saved run directory or a specific checkpoint.

```
uv run scripts/eval.py \
  --run-dir out/pretrain/tinystories \
  --eval-iters 20
```

The evaluation output includes:
- train loss
- valid loss
- train / valid CE loss
- train / valid MoE auxiliary loss
- train / valid MoE z-loss
- `valid_bits_per_byte`

`valid_bits_per_byte` is computed from CE loss against the raw validation text units actually evaluated. It is not polluted by MoE auxiliary loss terms or synthetic document-EoS cache tokens.
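For reference, bits per byte can be derived from a mean cross-entropy loss (in nats per token) by converting total nats to bits and dividing by the byte count of the evaluated raw text. A minimal sketch of the conversion:

```python
import math


def bits_per_byte(ce_loss_nats: float, num_tokens: int, num_bytes: int) -> float:
    """Convert mean CE loss (nats/token) into bits per byte of raw text:
    total nats -> total bits (divide by ln 2) -> divide by byte count."""
    total_bits = ce_loss_nats * num_tokens / math.log(2)
    return total_bits / num_bytes
```

This is why a better tokenizer (fewer tokens per byte) and a lower per-token loss both push bits per byte down, and why it is comparable across tokenizers in a way raw CE loss is not.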
The evaluation script is also read-only with respect to saved run configuration.
The dense model lives in `src/nano_transformer/model.py`.

`TransformerLM` currently supports:
- merged QKV projection
- grouped-query attention via `num_kv_heads`
- RoPE with optional scaling (`linear`, `yarn`)
- optional QK-Norm
- attention-logit softcapping
- final-logit softcapping
- zero-init residual branch outputs
- optional tied embeddings
- optional multi-token prediction heads
- runtime-updatable attention window size
Supporting building blocks include:
`Embedding`, `Linear`, `RMSNorm`, `SwiGLU`, `ScaledDotProductAttention`, `MultiHeadSelfAttention`, `TransformerBlock`
The attention stack is designed to be modern but still inspectable.
Implemented pieces:
- merged QKV projection to reduce projection overhead
- grouped-query attention to reduce KV cost and inference memory
- RoPE caches built on the correct device
- QK-Norm for stabilizing query/key magnitudes
- SDPA fast path when available
- explicit fallback path for correctness and debuggability
- sliding-window attention
- configurable global prefix tokens inside the local mask
- document masking support for multi-document training batches
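The interaction of the causal, sliding-window, and global-prefix rules can be stated in a few lines. An illustrative pure-Python sketch of the mask semantics (the actual implementation operates on tensors, not nested lists):

```python
def sliding_window_mask(seq_len: int, window: int, num_global: int = 0) -> list[list[bool]]:
    """Boolean attention mask: mask[i][j] is True when position j is
    visible to position i.

    Combines three rules: causal (j <= i), sliding window (i - j < window),
    and global prefix tokens (positions j < num_global stay visible to
    everyone, still subject to causality)."""
    mask = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            causal = j <= i
            in_window = i - j < window
            is_global = j < num_global
            row.append(causal and (in_window or is_global))
        mask.append(row)
    return mask
```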
The inference path lives in `src/nano_transformer/inference/`.
Supported workflows:
- prefill once, decode incrementally
- branch cache state for multi-sample generation
- decode from a prompt batch or a single prompt
- validate cache shapes and cache lengths explicitly
High-level helpers:
`generate`, `prefill`, `decode_step`, `sample_next_token`, `allocate_kv_caches`, `branch_kv_caches`
The MoE implementation lives in `src/nano_transformer/moe.py`.
It is fully vectorized in the model forward path and includes:
- top-k token-choice routing
- normalized routing weights
- shared experts
- capacity-based dropping
- routed-token accounting
- auxiliary load-balancing loss
- router z-loss
- optional bias-based router balancing
- synchronized routing stats under distributed training
MoE layers can be enabled selectively or across the full stack.
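The core of token-choice routing is easy to state outside the vectorized implementation. An illustrative, deliberately unvectorized sketch of top-k routing with capacity-based dropping (not the repo's actual code path):

```python
import math


def route_tokens(router_logits: list[list[float]], k: int, capacity: int):
    """Token-choice top-k routing with a per-expert capacity limit.

    Each token picks its k highest-probability experts; an expert already
    at capacity drops further assignments (in a real MoE layer the dropped
    tokens fall through to the residual/shared-expert path).
    Returns (assignments, dropped) where assignments maps
    (token, expert) -> normalized routing weight."""
    assignments: dict = {}
    load: dict = {}
    dropped = 0
    for t, logits in enumerate(router_logits):
        # Softmax over experts for this token.
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        z = sum(exps)
        probs = [e / z for e in exps]
        topk = sorted(range(len(logits)), key=lambda e: -probs[e])[:k]
        norm = sum(probs[e] for e in topk)
        for e in topk:
            if load.get(e, 0) < capacity:
                load[e] = load.get(e, 0) + 1
                assignments[(t, e)] = probs[e] / norm
            else:
                dropped += 1
    return assignments, dropped
```

The dropped-token count is exactly the quantity the dropped-token accounting and load-balancing losses are trying to keep small.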
The data path lives in `src/nano_transformer/training/data.py`.
Important behaviors:
- tokenize raw text files into reusable caches
- store cache metadata so tokenizer changes invalidate stale caches
- optionally emit and reuse document-id arrays
- prefix documents with an EoS/BOS-style special token when available
- prefer document-boundary batch starts
- mask cross-document targets in loss computation
- cap document length before batching
This gives the training loop a more realistic long-context and multi-document foundation than naive flat-token slicing.
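Cross-document target masking from the list above can be illustrated concisely. A sketch, assuming the conventional `ignore_index` of -100 for positions the loss should skip (the repo's internals may differ):

```python
def make_targets(tokens: list[int], doc_ids: list[int], ignore_index: int = -100) -> list[int]:
    """Next-token targets with cross-document positions masked out.

    Position i predicts tokens[i + 1]; if that token belongs to a
    different document, the target becomes `ignore_index` so the loss
    never asks the model to predict across a document boundary."""
    targets = []
    for i in range(len(tokens) - 1):
        if doc_ids[i] == doc_ids[i + 1]:
            targets.append(tokens[i + 1])
        else:
            targets.append(ignore_index)
    return targets
```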
The training stack lives in `src/nano_transformer/training/`.

The trainer supports:
- gradient accumulation
- AMP / autocast
- TF32
- optional `torch.compile`
- gradient clipping
- global grad-norm logging
- loss-spike and grad-spike alerts
- effective-batch-size-aware LR scaling
- LR-tied weight decay scheduling
- staged training schedules
- checkpoint save/load and resume
- best-checkpoint tracking
- DDP-aware metric reduction
The optimizer layer supports:
- AdamW
- Muon
- hybrid Muon+AdamW parameter grouping
- cautious weight decay
- explicit parameter-policy metadata for grouping decisions
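The hybrid-grouping idea is that Muon's orthogonalized update applies to 2-D weight matrices, while embeddings, norms, biases, and output heads stay on AdamW. An illustrative sketch over parameter names and shapes (name patterns here are hypothetical; the repo's `split_muon_params` is the real mechanism):

```python
def split_param_groups(named_shapes: dict[str, tuple[int, ...]]):
    """Split parameters into a Muon group (2-D matrices) and an AdamW
    group (everything else, plus embeddings and output heads, which are
    matrices but are conventionally excluded from Muon)."""
    muon, adamw = [], []
    for name, shape in named_shapes.items():
        is_matrix = len(shape) == 2
        is_embedding_or_head = "embed" in name or "lm_head" in name
        if is_matrix and not is_embedding_or_head:
            muon.append(name)
        else:
            adamw.append(name)
    return muon, adamw
```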
First-class schedule objects:
`TrainingStage`, `TrainingSchedule`
These can control:
- batch size
- context length
- attention window size
- MTP weights / horizon
Tokenizer round trip:

```python
from nano_transformer import Tokenizer

tokenizer = Tokenizer.from_files(
    vocab_filepath="data/train-bpe-tinystories-vocab.json",
    merges_filepath="data/train-bpe-tinystories-merges.txt",
    special_tokens=["<|endoftext|>"],
)
token_ids = tokenizer.encode("hello world")
text = tokenizer.decode(token_ids)
```

Model construction and generation:

```python
import torch

from nano_transformer import TransformerLM, generate

model = TransformerLM(
    vocab_size=50257,
    context_length=128,
    d_model=256,
    num_layers=4,
    num_heads=8,
    d_ff=1024,
    rope_theta=10000.0,
    tie_embeddings=True,
)
prompt = torch.tensor([[1, 2, 3]], dtype=torch.long)
output = generate(model, prompt, max_new_tokens=16, strategy="greedy")
```

KV-cache-backed decoding:

```python
from nano_transformer import allocate_kv_caches, prefill, decode_step

kv_caches = allocate_kv_caches(model, batch_size=1, total_seq_len=64)
logits, kv_caches = prefill(model, prompt, kv_caches=kv_caches)
logits, kv_caches = decode_step(model, output[:, -1:], kv_caches)
```

Training with the library API:

```python
from pathlib import Path

from nano_transformer import TransformerLM
from nano_transformer.training import Trainer, TrainerConfig, prepare_pretraining_data

train_dataset, valid_dataset, tokenizer = prepare_pretraining_data(
    train_text_path=Path("data/TinyStoriesV2-GPT4-train.txt"),
    valid_text_path=Path("data/TinyStoriesV2-GPT4-valid.txt"),
    vocab_path=Path("data/train-bpe-tinystories-vocab.json"),
    merges_path=Path("data/train-bpe-tinystories-merges.txt"),
    train_cache_path=Path("data/tinystories-train-tokens.npy"),
    valid_cache_path=Path("data/tinystories-valid-tokens.npy"),
)
model = TransformerLM(
    vocab_size=len(tokenizer.vocab),
    context_length=128,
    d_model=256,
    num_layers=4,
    num_heads=8,
    d_ff=1024,
    rope_theta=10000.0,
)
config = TrainerConfig(
    batch_size=16,
    context_length=128,
    max_iters=1000,
    device="cpu",
    output_dir="out/example",
)
trainer = Trainer(model, train_dataset, valid_dataset, config)
trainer.train()
```

The package root lazily re-exports the most useful entry points from `src/nano_transformer/__init__.py`.
Most useful top-level imports:
- model pieces: `TransformerLM`, `TransformerBlock`, `MultiHeadSelfAttention`, `SwiGLU`, `RMSNorm`
- inference: `generate`, `prefill`, `decode_step`, `KVCache`
- tokenization: `Tokenizer`, `train_bpe`
- training: `Trainer`, `TrainerConfig`, `prepare_pretraining_data`, `prepare_named_pretraining_data`, `create_optimizer`
- optimization helpers: `AdamW`, `Muon`, `MuonAdamW`, `split_muon_params`

For lower-level extension work, the training package also exports: `TrainingStage`, `TrainingSchedule`, `ParameterPolicy`, `build_parameter_policies`
- model core: `src/nano_transformer/model.py`
- MoE: `src/nano_transformer/moe.py`
- tokenizer training: `src/nano_transformer/tokenization/bpe.py`
- tokenizer runtime: `src/nano_transformer/tokenization/tokenizer.py`
- data pipeline: `src/nano_transformer/training/data.py`
- optimizer stack: `src/nano_transformer/training/optim.py`
- trainer: `src/nano_transformer/training/trainer.py`
- inference + KV cache: `src/nano_transformer/inference/decoding.py`, `src/nano_transformer/inference/kv_cache.py`
Run everything:

```
uv run pytest
```

Run focused suites:

```
uv run pytest tests/test_model.py
uv run pytest tests/test_inference.py
uv run pytest tests/test_moe.py
uv run pytest tests/test_optimizer.py
uv run pytest tests/test_trainer.py
uv run pytest tests/test_scripts.py
```

The test suite includes:
- model-level correctness and snapshot tests
- tokenizer and BPE behavior
- inference parity and KV-cache validation
- MoE routing and distributed-stat behavior
- optimizer grouping and update behavior
- checkpoint serialization
- script-level end-to-end checks for pretraining, evaluation, and sampling
This repository aims to occupy a useful middle ground:
- lower-level and more educational than a production LLM framework
- much more complete than a single-file toy GPT
- explicit enough that architectural and systems decisions are easy to inspect
- practical enough to support real experiments in dense training, MoE, and inference
This repo is intentionally not trying to be:
- a giant training platform with deep cluster orchestration
- a benchmark-maxing, hardware-specific kernel zoo
- a post-training / RLHF framework
- a one-file tutorial at the expense of clarity once the system grows
If you want a compact research codebase that still supports modern attention variants, MoE, KV caching, staged training, and end-to-end workflows, this repo is designed for that use case.