Implementation of the Tiny Recursive Model (TRM) from the paper "Less is More: Recursive Reasoning with Tiny Networks" (arXiv:2510.04871v1) by Alexia Jolicoeur-Martineau.
- Paper: arXiv:2510.04871v1
- Title: Less is More: Recursive Reasoning with Tiny Networks
- Author: Alexia Jolicoeur-Martineau
- Focus: Sudoku-Extreme dataset (9x9 grids)
This is a faithful implementation of the TRM-MLP variant for Sudoku puzzle solving. The model uses a tiny 2-layer recursive network that achieves state-of-the-art performance on Sudoku-Extreme puzzles through deep recursive reasoning.
Key Features:
- Single 2-layer network shared for all recursions
- Deep recursion with T=3 cycles and n=6 latent steps per cycle
- MLP-Mixer architecture (no self-attention)
- Exponential Moving Average (EMA) for training stability
- Adaptive Computation Time (ACT) for early stopping
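The core recursion can be sketched as follows. This is a minimal illustration of the scheme listed above (T=3 cycles, n=6 latent steps per cycle, gradients flowing only through the final cycle); `net_z` and `net_y` stand in for the single shared 2-layer network applied to the latent state and the answer, and their exact signatures are assumptions, not this repo's API.

```python
import torch

def trm_recursion(net_z, net_y, x, y, z, T=3, n=6):
    """Sketch of TRM's deep recursion: T cycles of n latent updates plus one answer update."""
    # All but the last cycle run without gradient tracking, so memory stays flat
    # even though the effective depth is T * (n + 1) * num_layers.
    with torch.no_grad():
        for _ in range(T - 1):
            for _ in range(n):
                z = net_z(x, y, z)   # refine the latent reasoning state
            y = net_y(y, z)          # refine the current answer embedding
    # Final cycle keeps the computation graph for backpropagation.
    for _ in range(n):
        z = net_z(x, y, z)
    y = net_y(y, z)
    return y, z
```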
Target Performance:
- Test Accuracy: 87.4%
- Parameters: 5M
- Effective Depth: 42 layers (T × (n+1) × num_layers = 3 × 7 × 2)
- Python 3.9-3.12 (3.12 recommended)
- GPU recommended: CUDA-capable GPU or Apple Silicon with MPS
- ~2GB GPU memory for training with batch size 768
- Note: PyTorch 2.2.2 is the latest version with macOS x86_64 (Intel Mac) support
- uv package manager
- Clone or download this repository
- Install uv if you haven't already:
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
- Install dependencies:
```bash
uv sync
```
Or install in your current environment:
```bash
uv pip install -e .
```
Place your Sudoku dataset files in the data/ directory:
```
trm/
├── data/
│   ├── train.txt    # Training puzzles (1K samples)
│   └── test.txt     # Test puzzles (423K samples)
```
Data Format:
- Each puzzle should be a 9x9 grid flattened to 81 values
- Values: 0 (empty cell) or 1-9 (filled cells)
- Input: partially filled puzzles
- Output: complete solutions
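As a rough sketch, a single dataset line could be parsed like this. The comma-separated layout (81 puzzle digits, then 81 solution digits) and the parse_line name are assumptions for illustration; see dataset.py for the actual format.

```python
import torch

def parse_line(line: str):
    # Assumed layout: "<81 puzzle digits>,<81 solution digits>" per line.
    puzzle_str, solution_str = line.strip().split(",")
    puzzle = torch.tensor([int(c) for c in puzzle_str], dtype=torch.long)      # 0 = empty cell
    solution = torch.tensor([int(c) for c in solution_str], dtype=torch.long)  # values 1-9
    assert puzzle.shape == (81,) and solution.shape == (81,)
    return puzzle, solution
```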
Run the training script:
```bash
uv run python train.py
```
Training Hyperparameters (hardcoded in train.py):
- Learning rate: 1e-4
- Weight decay: 1.0
- Batch size: 768
- Training epochs: 60,000
- Data augmentation: 1000 shuffling augmentations per training sample
- Max supervision steps: 16
- EMA decay: 0.999
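The EMA update itself is straightforward. Below is a minimal sketch using the 0.999 decay, assuming a separate ema_model copy is kept alongside the trained model; the function name and usage pattern are illustrative, not necessarily what train.py does internally.

```python
import torch

@torch.no_grad()
def update_ema(model: torch.nn.Module, ema_model: torch.nn.Module, decay: float = 0.999):
    # Blend each EMA parameter toward the current weights: ema = decay * ema + (1 - decay) * w
    for p, ema_p in zip(model.parameters(), ema_model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)

# Assumed usage: ema_model = copy.deepcopy(model) once at startup, then call
# update_ema(model, ema_model) after every optimizer step; evaluate with ema_model.
```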
Training will save checkpoints to checkpoints/ directory.
Run the evaluation script:
```bash
uv run python eval.py \
    --checkpoint checkpoints/checkpoint_ema_epoch_100.pt \
    --data-dir data/sudoku-extreme-1k-aug-1000
```
Options:
- --checkpoint: Path to checkpoint file (use EMA checkpoint for best results)
- --data-dir: Path to dataset directory
- --batch-size: Batch size for evaluation (default: 768)
- --device: Device to use (default: cuda if available, else cpu)
- --seed: Random seed for reproducibility (default: 42)
The evaluation script reports:
- Per-puzzle accuracy: Percentage of puzzles solved correctly (all 81 cells)
- Per-cell accuracy: Percentage of individual cells predicted correctly
- Target accuracy from paper: 87.4%
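Both metrics can be computed directly from the predicted and target grids. A small sketch, assuming preds and targets are LongTensors of shape (num_puzzles, 81):

```python
import torch

def accuracy_metrics(preds: torch.Tensor, targets: torch.Tensor):
    cell_correct = preds == targets                               # (N, 81) boolean mask
    per_cell = cell_correct.float().mean().item()                 # fraction of cells correct
    per_puzzle = cell_correct.all(dim=1).float().mean().item()    # all 81 cells must match
    return per_puzzle, per_cell
```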
```
trm/
├── model.py                  # TRM architecture (embeddings, layers, network)
├── train.py                  # Training script with deep supervision
├── eval.py                   # Evaluation script
├── dataset.py                # Sudoku dataset loader and augmentation
├── losses.py                 # Loss functions (prediction + halting)
├── pyproject.toml            # Project configuration and dependencies
├── README.md                 # This file
├── IMPLEMENTATION_PLAN.md    # Detailed implementation plan
├── 2510.04871v1.pdf          # Paper
├── data/                     # Data directory (gitignored)
└── checkpoints/              # Model checkpoints (gitignored)
```
- Hidden size: 512
- Number of layers: 2
- Context length: 81 (9×9 grid)
- Vocabulary size: 10 (tokens 0-9)
- Recursions per cycle: n=6
- Deep recursion cycles: T=3
- Activation: SwiGLU
- Normalization: RMSNorm
- Position encoding: Rotary embeddings
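For reference, here is a minimal sketch of an RMSNorm + SwiGLU feed-forward block consistent with the list above; the expansion factor, module names, and residual placement are assumptions rather than this repo's exact layer definitions in model.py.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the root-mean-square over the last dimension, then rescale.
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) * self.weight

class SwiGLUBlock(nn.Module):
    def __init__(self, hidden_size: int = 512, expansion: float = 4.0):
        super().__init__()
        inner = int(hidden_size * expansion)
        self.norm = RMSNorm(hidden_size)
        self.w_gate = nn.Linear(hidden_size, inner, bias=False)
        self.w_up = nn.Linear(hidden_size, inner, bias=False)
        self.w_down = nn.Linear(inner, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm(x)
        # SwiGLU: silu(gate) * up, projected back down, with a residual connection.
        return x + self.w_down(F.silu(self.w_gate(h)) * self.w_up(h))
```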
This implementation strictly follows the paper specifications:
- All hyperparameters are hardcoded as specified in the paper
- No configuration files or additional features
- Focus on TRM-MLP variant only (no self-attention)
- Flat project structure for simplicity
For implementation details, see:
- Section 2.1: Base architecture components
- Section 4: TRM improvements over HRM
- Algorithm 3 (Figure 3, Page 5): Training pseudocode
- Table 1 (Page 5): Performance comparison
- Page 11: Complete hyperparameters
- Section 5 (Page 8): Dataset specifications