taivu1998/Nano-Transformer

Nano-Transformer

nano-transformer is a from-scratch LLM codebase that starts with core Transformer primitives and grows into a compact but genuinely end-to-end training stack.

It is meant for people who want more than a toy GPT file, but less ceremony than a large production framework. The repo includes tokenizer training, dense and MoE model variants, modern attention features, staged pretraining, checkpointing, evaluation, and sampling in a layout that is explicit enough to study and extend comfortably.

Why This Repo Exists

  • Build the model stack directly in PyTorch instead of hiding it behind a large framework.
  • Keep the code organized enough for real experimentation: dense baselines, GQA, RoPE scaling, KV cache, MoE, schedules, and optimizer variants all live in a coherent package.
  • Preserve readability. The code is designed to be inspectable, testable, and modifiable without turning into a single monolithic training script.
  • Cover the full workflow:
    • train or load a BPE tokenizer
    • prepare token caches
    • pretrain a model
    • resume from checkpoints
    • evaluate
    • sample text

Feature Snapshot

  • Tokenization: GPT-2-style byte-level BPE, tokenizer loading, special-token handling, tokenizer training utilities
  • Core model: Embedding, RMSNorm, SwiGLU, merged QKV, RoPE, QK-Norm, grouped-query attention, optional tied embeddings
  • Attention variants: SDPA fast path, explicit fallback path, sliding-window attention, global prefix tokens, RoPE scaling (linear, yarn)
  • Inference: preallocated KV cache, prefill, decode_step, cache branching, greedy / sample / top-k / top-p decoding
  • MoE: vectorized token-choice routing, shared experts, capacity control, dropped-token accounting, aux loss, z-loss, bias balancing
  • Training: document-aware batching, EoS-prefixed document tokenization, max-document-length truncation, gradient accumulation, staged schedules
  • Optimizers: AdamW, Muon, hybrid Muon+AdamW grouping, cautious weight decay, LR-tied weight decay scheduling
  • Runtime: TF32 support, autocast, optional torch.compile, DDP support, checkpoint save/load, resume support
  • Validation: extensive unit tests, snapshot tests, script-level coverage for pretrain / eval / sample flows

Current Validation

Current local validation for this repository:

  • uv run pytest -> 117 passed, 2 skipped
  • script flows are covered in tests/test_scripts.py
  • the skipped tests are memory-limit checks that depend on local resource limits

End-to-End Workflow

flowchart LR
    A["Raw text corpora"] --> B["BPE tokenizer<br/>vocab + merges"]
    B --> C["Tokenized dataset caches<br/>+ document ids"]
    C --> D["TransformerLM / MoE model"]
    D --> E["Trainer<br/>staged schedules + optimizer"]
    E --> F["Checkpoints<br/>latest.pt / best.pt"]
    F --> G["eval.py"]
    F --> H["sample.py"]

Repository Layout

src/nano_transformer/
  __init__.py
  model.py
  moe.py
  inference/
    decoding.py
    kv_cache.py
  tokenization/
    tokenizer.py
    bpe.py
    gpt2.py
  training/
    checkpointing.py
    data.py
    optim.py
    trainer.py

scripts/
  pretrain.py
  sample.py
  eval.py
  train_bpe.py

tests/
  test_model.py
  test_inference.py
  test_moe.py
  test_optimizer.py
  test_trainer.py
  test_scripts.py
  ...

Setup

This repository uses uv and is packaged with a src/ layout.

Requirements

  • Python >= 3.11
  • uv
  • PyTorch is resolved from pyproject.toml based on platform

Recommended Install

brew install uv
uv sync

If you prefer not to sync eagerly, uv run ... will also resolve and run commands in the project environment.

For editable package-style imports outside uv run, you can also do:

uv pip install -e .

Quick Sanity Check

uv run pytest

What Is Checked In vs. Expected Locally

This distinction matters: source code, scripts, and tests are checked into git, while datasets and generated artifacts are expected locally for real training.

The large raw text corpora are intentionally not committed. By default, the named dataset shortcuts expect local files such as:

  • data/TinyStoriesV2-GPT4-train.txt
  • data/TinyStoriesV2-GPT4-valid.txt
  • data/owt_train.txt
  • data/owt_valid.txt

Generated artifacts are also local:

  • token caches (*.npy)
  • document-id caches
  • checkpoints
  • run directories under out/

If you do not want to use the named dataset shortcuts, you can point the training script directly at your own files with --train-text-path, --valid-text-path, --vocab-path, and --merges-path.

Data Setup

Supported named datasets

scripts/pretrain.py accepts:

  • tinystories
  • openwebtext
  • owt

Download expected raw files

Example:

mkdir -p data
cd data

curl -L -o TinyStoriesV2-GPT4-train.txt \
  https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-train.txt
curl -L -o TinyStoriesV2-GPT4-valid.txt \
  https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-valid.txt

curl -L -o owt_train.txt.gz \
  https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_train.txt.gz
gunzip -f owt_train.txt.gz

curl -L -o owt_valid.txt.gz \
  https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_valid.txt.gz
gunzip -f owt_valid.txt.gz

Data pipeline behavior

The training data path does more than flatten the token stream and slice windows:

  • documents are tokenized with a document-EoS prefix when possible
  • token caches are fingerprinted so tokenizer changes invalidate stale caches
  • optional document-id arrays are stored alongside token caches
  • batch starts can prefer document boundaries
  • cross-document targets can be masked out
  • documents can be truncated to --max-document-length

These behaviors live primarily in src/nano_transformer/training/data.py.
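The cache-fingerprinting behavior can be sketched in a few lines: hash the tokenizer artifacts, store the digest next to the token cache, and rebuild whenever it no longer matches. The helper names here are hypothetical illustrations, not the repo's actual API:

```python
import hashlib
import json
from pathlib import Path

def tokenizer_fingerprint(vocab_path: Path, merges_path: Path) -> str:
    """Hash the tokenizer artifacts so any change produces a new fingerprint."""
    h = hashlib.sha256()
    for p in (vocab_path, merges_path):
        h.update(p.read_bytes())
    return h.hexdigest()

def cache_is_stale(cache_meta_path: Path, fingerprint: str) -> bool:
    """A cache is stale if its stored fingerprint is missing or differs."""
    if not cache_meta_path.exists():
        return True
    meta = json.loads(cache_meta_path.read_text())
    return meta.get("tokenizer_fingerprint") != fingerprint
```

Any edit to the vocab or merges files changes the digest, so stale caches are detected without comparing token arrays.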

Quick Start

1. Run tests

uv run pytest

2. Train a small dense model

uv run scripts/pretrain.py \
  --dataset tinystories \
  --context-length 128 \
  --d-model 256 \
  --num-layers 4 \
  --num-heads 8 \
  --batch-size 16 \
  --gradient-accumulation-steps 4 \
  --max-iters 1000 \
  --device cpu

3. Sample from the checkpoint

uv run scripts/sample.py \
  --run-dir out/pretrain/tinystories \
  --prompt "Once upon a time" \
  --max-new-tokens 64 \
  --strategy top_k \
  --top-k 40

4. Evaluate the run

uv run scripts/eval.py \
  --run-dir out/pretrain/tinystories \
  --eval-iters 20

Run directory contents

Training writes a run directory such as:

out/pretrain/tinystories/
  latest.pt
  best.pt
  run_config.json
  trainer_config.json
  metrics.jsonl

What these files are for:

  • latest.pt: latest checkpoint
  • best.pt: best validation checkpoint tracked so far
  • run_config.json: dataset paths, tokenizer paths, special tokens, and model config used for the run
  • trainer_config.json: trainer/runtime configuration for the run
  • metrics.jsonl: per-step training and evaluation metrics such as loss, CE loss, tokens/sec, grad norm, active schedule state, and MoE counters
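Because metrics.jsonl holds one JSON object per line, it is easy to inspect ad hoc. A minimal reader might look like this; the field names used in the test ("step", "loss") are illustrative assumptions, not a guaranteed schema:

```python
import json
from pathlib import Path

def load_metrics(path: Path) -> list[dict]:
    """Parse a JSONL metrics file: one JSON object per non-empty line."""
    records = []
    for line in path.read_text().splitlines():
        if line.strip():
            records.append(json.loads(line))
    return records

def latest_value(records: list[dict], key: str):
    """Return the most recently logged value under `key`, or None."""
    for record in reversed(records):
        if key in record:
            return record[key]
    return None
```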

Training a Tokenizer

There are two practical ways to train or refresh tokenizer artifacts.

Library API

This is the most flexible path:

from nano_transformer import train_bpe

vocab, merges = train_bpe(
    input_path="data/TinyStoriesV2-GPT4-train.txt",
    vocab_size=10000,
    special_tokens=["<|endoftext|>"],
)

Helper script

scripts/train_bpe.py contains small helper functions for serializing tokenizer artifacts for the bundled datasets. It is intentionally lightweight, and it is not a full CLI wrapper yet.

uv run scripts/train_bpe.py

If you want custom tokenizer-training behavior, prefer the library API or edit the helper functions in scripts/train_bpe.py.

Pretraining Guide

scripts/pretrain.py is the main training entrypoint.

Default baseline

The default configuration is intentionally modest:

  • context_length: 128
  • d_model: 256
  • num_layers: 4
  • num_heads: 8
  • batch_size: 16
  • optimizer: adamw
  • tie_embeddings: True
  • scale_embeddings: True
  • use_qk_norm: True
  • zero_init_residual: True

Common dense run

uv run scripts/pretrain.py \
  --dataset tinystories \
  --context-length 256 \
  --d-model 384 \
  --num-layers 6 \
  --num-heads 8 \
  --batch-size 8 \
  --gradient-accumulation-steps 8 \
  --max-iters 5000 \
  --learning-rate 3e-4 \
  --device cuda

Long-context staged run

You can stage context length, batch size, attention window size, and MTP weights over training.

uv run scripts/pretrain.py \
  --dataset tinystories \
  --context-length 256 \
  --context-stage 1000:512 \
  --context-stage 3000:1024 \
  --training-stage start=1000,window=256 \
  --training-stage start=3000,window=full,mtp=1.0|0.25 \
  --batch-size 16 \
  --batch-size-stage 3000:8 \
  --max-iters 5000 \
  --rope-scaling-type yarn \
  --rope-scaling-factor 4.0

Notes:

  • window=<int> enables sliding-window attention.
  • window=full explicitly restores full attention.
  • --context-stage and --batch-size-stage are shorthand schedules.
  • --training-stage is the general mechanism and can change:
    • batch size
    • context length
    • attention window
    • MTP weights
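One way to picture staged schedules: each stage carries a start iteration and a set of overrides, and the active settings at any iteration are the base config merged with every stage already reached. This is a simplified stand-in for the repo's TrainingStage/TrainingSchedule objects, not their actual implementation:

```python
from dataclasses import dataclass, field

@dataclass
class Stage:
    """Illustrative training stage: its settings apply from `start` onward."""
    start: int
    settings: dict = field(default_factory=dict)

def resolve_settings(base: dict, stages: list[Stage], iteration: int) -> dict:
    """Merge base settings with every stage whose start has been reached.

    Later stages override earlier ones, so sorting by start iteration
    gives a deterministic result regardless of input order.
    """
    resolved = dict(base)
    for stage in sorted(stages, key=lambda s: s.start):
        if iteration >= stage.start:
            resolved.update(stage.settings)
    return resolved
```

With the long-context example above, iteration 2000 would resolve to context length 512 with a 256-token window, and iteration 4000 to context length 1024 with full attention and the smaller batch size.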

MoE run

uv run scripts/pretrain.py \
  --dataset tinystories \
  --num-experts 8 \
  --num-experts-per-token 2 \
  --num-shared-experts 1 \
  --moe-layers all \
  --moe-capacity-factor 1.25 \
  --moe-aux-loss-coef 0.01 \
  --moe-z-loss-coef 0.001 \
  --moe-router-balance-strategy aux_loss

Muon run

uv run scripts/pretrain.py \
  --dataset tinystories \
  --optimizer muon \
  --cautious-weight-decay \
  --muon-momentum 0.95 \
  --muon-ns-steps 5

Distributed run

python3 -m torch.distributed.run \
  --nproc_per_node 2 \
  scripts/pretrain.py \
  --dataset tinystories \
  --device cuda

On non-CUDA distributed environments, the script falls back to CPU instead of trying unsupported device backends.

Pretraining flags that matter most

Model shape

  • --context-length
  • --d-model
  • --d-ff
  • --num-layers
  • --num-heads
  • --num-kv-heads

Attention and long-context behavior

  • --rope-theta
  • --rope-scaling-type
  • --rope-scaling-factor
  • --rope-base-context-length
  • --attention-window-size
  • --attention-global-tokens
  • --no-qk-norm
  • --attn-logit-softcap
  • --final-logit-softcap

MoE

  • --num-experts
  • --num-experts-per-token
  • --num-shared-experts
  • --moe-layers
  • --moe-capacity-factor
  • --moe-router-balance-strategy

Data behavior

  • --max-document-length
  • --document-boundary-sampling-probability
  • --force-rebuild-cache

Optimization

  • --optimizer
  • --learning-rate
  • --weight-decay
  • --gradient-accumulation-steps
  • --cautious-weight-decay
  • --no-scale-lr-with-batch-size
  • --no-schedule-weight-decay-with-lr

Runtime

  • --device
  • --dtype
  • --compile
  • --pin-memory
  • --no-tf32
  • --resume-from

Sampling Guide

scripts/sample.py supports:

  • greedy decoding
  • multinomial sampling
  • top-k sampling
  • top-p / nucleus sampling
  • cache-backed generation
  • multi-sample generation
  • optional no-cache path for debugging and parity checks
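The filtering behind top-k and top-p decoding is compact enough to sketch. This is a generic plain-Python illustration of the two strategies, not the script's actual implementation:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_filter(logits, k):
    """Keep the k highest logits; mask the rest to -inf before sampling."""
    threshold = sorted(logits, reverse=True)[k - 1]
    return [x if x >= threshold else float("-inf") for x in logits]

def top_p_filter(logits, p):
    """Keep the smallest set of tokens whose cumulative probability reaches p."""
    probs = softmax(logits)
    order = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    return [logits[i] if i in keep else float("-inf") for i in range(len(logits))]
```

Sampling then draws from the softmax of the filtered logits; greedy decoding is the degenerate case k = 1.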

Greedy

uv run scripts/sample.py \
  --run-dir out/pretrain/tinystories \
  --prompt "The dragon" \
  --strategy greedy

Top-k

uv run scripts/sample.py \
  --run-dir out/pretrain/tinystories \
  --prompt "The dragon" \
  --strategy top_k \
  --top-k 50 \
  --temperature 0.8 \
  --max-new-tokens 128

Top-p without KV cache

uv run scripts/sample.py \
  --run-dir out/pretrain/tinystories \
  --prompt "The dragon" \
  --strategy top_p \
  --top-p 0.9 \
  --num-samples 4 \
  --no-kv-cache

Notes:

  • empty prompts are supported if the tokenizer exposes a special token that can be used as a BOS seed
  • --num-samples works for both cache-backed and no-cache generation paths
  • the script rebuilds the saved base model configuration from run_config.json

Evaluation Guide

scripts/eval.py evaluates a saved run directory or a specific checkpoint.

uv run scripts/eval.py \
  --run-dir out/pretrain/tinystories \
  --eval-iters 20

The evaluation output includes:

  • train loss
  • valid loss
  • train / valid CE loss
  • train / valid MoE auxiliary loss
  • train / valid MoE z-loss
  • valid_bits_per_byte

valid_bits_per_byte is computed from CE loss against the raw validation bytes actually evaluated. It is not polluted by MoE auxiliary loss terms or by synthetic document-EoS cache tokens.
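The conversion itself is a unit change: mean cross-entropy in nats per token, scaled by the token count, divided by ln 2 to get bits, then divided by the number of raw bytes. A sketch of that arithmetic (assuming, as is standard, that the CE loss is reported in nats):

```python
import math

def bits_per_byte(ce_nats_per_token: float, num_tokens: int, num_bytes: int) -> float:
    """Convert mean cross-entropy (nats/token) into bits per raw byte of text."""
    total_bits = ce_nats_per_token * num_tokens / math.log(2)
    return total_bits / num_bytes
```

For example, a CE loss of ln 2 nats per token on text with one token per byte is exactly 1 bit per byte.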

The evaluation script is also read-only with respect to saved run configuration.

Architecture Walkthrough

Model

The dense model lives in src/nano_transformer/model.py.

TransformerLM currently supports:

  • merged QKV projection
  • grouped-query attention via num_kv_heads
  • RoPE with optional scaling (linear, yarn)
  • optional QK-Norm
  • attention-logit softcapping
  • final-logit softcapping
  • zero-init residual branch outputs
  • optional tied embeddings
  • optional multi-token prediction heads
  • runtime-updatable attention window size

Supporting building blocks include:

  • Embedding
  • Linear
  • RMSNorm
  • SwiGLU
  • ScaledDotProductAttention
  • MultiHeadSelfAttention
  • TransformerBlock
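Two of these building blocks are small enough to sketch numerically. This is a plain-Python scalar/vector illustration of the math, not the repo's PyTorch modules:

```python
import math

def rms_norm(x, gain, eps=1e-6):
    """RMSNorm: divide x by its root-mean-square, then apply a learned gain."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for g, v in zip(gain, x)]

def swiglu(x, w_gate, w_up):
    """SwiGLU on a scalar input: silu(w_gate * x) * (w_up * x)."""
    gate = w_gate * x
    silu = gate / (1.0 + math.exp(-gate))  # silu(z) = z * sigmoid(z)
    return silu * (w_up * x)
```

In the real model these operate on full hidden vectors with learned weight matrices, followed by a down projection.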

Attention

The attention stack is designed to be modern but still inspectable.

Implemented pieces:

  • merged QKV projection to reduce projection overhead
  • grouped-query attention to reduce KV cost and inference memory
  • RoPE caches built on the correct device
  • QK-Norm for stabilizing query/key magnitudes
  • SDPA fast path when available
  • explicit fallback path for correctness and debuggability
  • sliding-window attention
  • configurable global prefix tokens inside the local mask
  • document masking support for multi-document training batches
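The interaction of causality, a sliding window, and global prefix tokens is easiest to see as a boolean mask. This sketch uses one common window convention (query q sees key k when q - k < window); the repo's exact convention may differ:

```python
def attention_mask(seq_len, window=None, global_tokens=0):
    """Build a causal attention mask; mask[q][k] is True if q may attend to k.

    With a window, each position sees only the most recent `window` keys,
    except that the first `global_tokens` positions stay visible to all.
    """
    mask = [[False] * seq_len for _ in range(seq_len)]
    for q in range(seq_len):
        for k in range(q + 1):  # causal: only keys at or before the query
            in_window = window is None or q - k < window
            is_global = k < global_tokens
            mask[q][k] = in_window or is_global
    return mask
```

With window=2 and one global token, position 4 attends to positions 3 and 4 plus the global position 0.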

Inference and KV Cache

The inference path lives in src/nano_transformer/inference/.

Supported workflows:

  • prefill once, decode incrementally
  • branch cache state for multi-sample generation
  • decode from a prompt batch or a single prompt
  • validate cache shapes and cache lengths explicitly

High-level helpers:

  • generate
  • prefill
  • decode_step
  • sample_next_token
  • allocate_kv_caches
  • branch_kv_caches
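The cache lifecycle (preallocate, fill during prefill and decode, branch for multi-sample generation) can be illustrated with a toy container. This is a conceptual sketch, not the repo's KVCache, which stores per-layer key/value tensors:

```python
class KVCacheSketch:
    """Toy preallocated KV cache: fixed capacity, append-only, branchable."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.keys, self.values = [], []

    def append(self, k, v):
        """Store one decoded step's key/value; enforce the preallocated bound."""
        if len(self.keys) >= self.capacity:
            raise ValueError("KV cache is full; allocate a longer cache")
        self.keys.append(k)
        self.values.append(v)

    def branch(self, num_branches):
        """Copy cache state so several samples can decode independently."""
        branches = []
        for _ in range(num_branches):
            b = KVCacheSketch(self.capacity)
            b.keys, b.values = list(self.keys), list(self.values)
            branches.append(b)
        return branches
```

Branching after prefill is what lets --num-samples reuse one prompt pass while each sample decodes its own continuation.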

MoE

The MoE implementation lives in src/nano_transformer/moe.py.

It is fully vectorized in the model forward path and includes:

  • top-k token-choice routing
  • normalized routing weights
  • shared experts
  • capacity-based dropping
  • routed-token accounting
  • auxiliary load-balancing loss
  • router z-loss
  • optional bias-based router balancing
  • synchronized routing stats under distributed training

MoE layers can be enabled selectively or across the full stack.
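The core routing loop (top-k expert choice per token, softmax weights over the chosen experts, capacity-based dropping) can be sketched in plain Python. The repo's version is vectorized; this per-token loop only illustrates the semantics:

```python
import math

def route_tokens(router_logits, k, capacity):
    """Token-choice top-k routing with per-expert capacity.

    Returns (assignments, dropped): assignments maps token -> list of
    (expert, weight); tokens past an expert's capacity are dropped there.
    """
    num_experts = len(router_logits[0])
    load = [0] * num_experts
    assignments, dropped = {}, []
    for token, logits in enumerate(router_logits):
        top = sorted(range(num_experts), key=lambda e: logits[e], reverse=True)[:k]
        # softmax over the selected experts only, so kept weights sum to 1
        m = max(logits[e] for e in top)
        exps = {e: math.exp(logits[e] - m) for e in top}
        total = sum(exps.values())
        picks = []
        for e in top:
            if load[e] < capacity:
                load[e] += 1
                picks.append((e, exps[e] / total))
            else:
                dropped.append((token, e))
        assignments[token] = picks
    return assignments, dropped
```

The dropped list is what the dropped-token accounting and auxiliary balancing losses act on: a well-balanced router keeps it small.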

Data Pipeline

The data path lives in src/nano_transformer/training/data.py.

Important behaviors:

  • tokenize raw text files into reusable caches
  • store cache metadata so tokenizer changes invalidate stale caches
  • optionally emit and reuse document-id arrays
  • prefix documents with an EoS/BOS-style special token when available
  • prefer document-boundary batch starts
  • mask cross-document targets in loss computation
  • cap document length before batching

This gives the training loop a more realistic long-context and multi-document foundation than naive flat-token slicing.
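The cross-document masking rule is simple to state precisely: a target position is valid only if the token it predicts belongs to the same document as its input token. A sketch of that rule over a document-id array (the function name is illustrative, not the repo's API):

```python
def cross_document_target_mask(doc_ids, start, context_length):
    """Mask next-token targets that would cross a document boundary.

    Target t in the window predicts token start + t + 1; it is kept only
    when that token is in the same document as the input token at start + t.
    """
    mask = []
    for t in range(context_length):
        same_doc = doc_ids[start + t] == doc_ids[start + t + 1]
        mask.append(same_doc)
    return mask
```

Masked positions are simply excluded from the loss, so the model is never trained to predict the first token of an unrelated document.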

Optimization and Training

The training stack lives in src/nano_transformer/training/.

The trainer supports:

  • gradient accumulation
  • AMP / autocast
  • TF32
  • optional torch.compile
  • gradient clipping
  • global grad-norm logging
  • loss-spike and grad-spike alerts
  • effective-batch-size-aware LR scaling
  • LR-tied weight decay scheduling
  • staged training schedules
  • checkpoint save/load and resume
  • best-checkpoint tracking
  • DDP-aware metric reduction

The optimizer layer supports:

  • AdamW
  • Muon
  • hybrid Muon+AdamW parameter grouping
  • cautious weight decay
  • explicit parameter-policy metadata for grouping decisions

First-class schedule objects:

  • TrainingStage
  • TrainingSchedule

These can control:

  • batch size
  • context length
  • attention window size
  • MTP weights / horizon

Python API Examples

Load a tokenizer and encode text

from nano_transformer import Tokenizer

tokenizer = Tokenizer.from_files(
    vocab_filepath="data/train-bpe-tinystories-vocab.json",
    merges_filepath="data/train-bpe-tinystories-merges.txt",
    special_tokens=["<|endoftext|>"],
)

token_ids = tokenizer.encode("hello world")
text = tokenizer.decode(token_ids)

Build a model and generate text

import torch
from nano_transformer import TransformerLM, generate

model = TransformerLM(
    vocab_size=50257,
    context_length=128,
    d_model=256,
    num_layers=4,
    num_heads=8,
    d_ff=1024,
    rope_theta=10000.0,
    tie_embeddings=True,
)

prompt = torch.tensor([[1, 2, 3]], dtype=torch.long)
output = generate(model, prompt, max_new_tokens=16, strategy="greedy")

Use KV-cache helpers directly

from nano_transformer import allocate_kv_caches, prefill, decode_step

kv_caches = allocate_kv_caches(model, batch_size=1, total_seq_len=64)
logits, kv_caches = prefill(model, prompt, kv_caches=kv_caches)
logits, kv_caches = decode_step(model, output[:, -1:], kv_caches)

Prepare data and build a trainer

from pathlib import Path

from nano_transformer import TransformerLM
from nano_transformer.training import Trainer, TrainerConfig, prepare_pretraining_data

train_dataset, valid_dataset, tokenizer = prepare_pretraining_data(
    train_text_path=Path("data/TinyStoriesV2-GPT4-train.txt"),
    valid_text_path=Path("data/TinyStoriesV2-GPT4-valid.txt"),
    vocab_path=Path("data/train-bpe-tinystories-vocab.json"),
    merges_path=Path("data/train-bpe-tinystories-merges.txt"),
    train_cache_path=Path("data/tinystories-train-tokens.npy"),
    valid_cache_path=Path("data/tinystories-valid-tokens.npy"),
)

model = TransformerLM(
    vocab_size=len(tokenizer.vocab),
    context_length=128,
    d_model=256,
    num_layers=4,
    num_heads=8,
    d_ff=1024,
    rope_theta=10000.0,
)

config = TrainerConfig(
    batch_size=16,
    context_length=128,
    max_iters=1000,
    device="cpu",
    output_dir="out/example",
)

trainer = Trainer(model, train_dataset, valid_dataset, config)
trainer.train()

Public API Surface

The package root lazily re-exports the most useful entry points from src/nano_transformer/__init__.py.

Most useful top-level imports:

  • model pieces:
    • TransformerLM
    • TransformerBlock
    • MultiHeadSelfAttention
    • SwiGLU
    • RMSNorm
  • inference:
    • generate
    • prefill
    • decode_step
    • KVCache
  • tokenization:
    • Tokenizer
    • train_bpe
  • training:
    • Trainer
    • TrainerConfig
    • prepare_pretraining_data
    • prepare_named_pretraining_data
    • create_optimizer
  • optimization helpers:
    • AdamW
    • Muon
    • MuonAdamW
    • split_muon_params

For lower-level extension work, the training package also exports:

  • TrainingStage
  • TrainingSchedule
  • ParameterPolicy
  • build_parameter_policies

Suggested Entry Points for Reading the Code

  • src/nano_transformer/model.py: the dense Transformer and its building blocks
  • src/nano_transformer/moe.py: routing, experts, and balancing losses
  • src/nano_transformer/inference/: KV cache allocation and decoding
  • src/nano_transformer/training/trainer.py: the training loop and staged schedules
  • scripts/pretrain.py: how everything is wired together end to end

Testing and Development

Full suite

uv run pytest

Targeted suites

uv run pytest tests/test_model.py
uv run pytest tests/test_inference.py
uv run pytest tests/test_moe.py
uv run pytest tests/test_optimizer.py
uv run pytest tests/test_trainer.py
uv run pytest tests/test_scripts.py

What is covered

The test suite includes:

  • model-level correctness and snapshot tests
  • tokenizer and BPE behavior
  • inference parity and KV-cache validation
  • MoE routing and distributed-stat behavior
  • optimizer grouping and update behavior
  • checkpoint serialization
  • script-level end-to-end checks for pretraining, evaluation, and sampling

Philosophy

This repository aims to occupy a useful middle ground:

  • lower-level and more educational than a production LLM framework
  • much more complete than a single-file toy GPT
  • explicit enough that architectural and systems decisions are easy to inspect
  • practical enough to support real experiments in dense training, MoE, and inference

Non-Goals

This repo is intentionally not trying to be:

  • a giant training platform with deep cluster orchestration
  • a benchmark-maxing, hardware-specific kernel zoo
  • a post-training / RLHF framework
  • a one-file tutorial at the expense of clarity once the system grows

If you want a compact research codebase that still supports modern attention variants, MoE, KV caching, staged training, and end-to-end workflows, this repo is designed for that use case.
