TRM Implementation for Sudoku-Extreme

Implementation of the Tiny Recursive Model (TRM) from the paper "Less is More: Recursive Reasoning with Tiny Networks" (arXiv:2510.04871v1) by Alexia Jolicoeur-Martineau.

Paper Reference

  • Paper: arXiv:2510.04871v1
  • Title: Less is More: Recursive Reasoning with Tiny Networks
  • Author: Alexia Jolicoeur-Martineau
  • Focus: Sudoku-Extreme dataset (9x9 grids)

Project Description

This is a faithful implementation of the TRM-MLP variant for Sudoku puzzle solving. The model uses a tiny 2-layer recursive network that achieves state-of-the-art performance on Sudoku-Extreme puzzles through deep recursive reasoning.

Key Features:

  • Single 2-layer network shared for all recursions
  • Deep recursion with T=3 cycles and n=6 latent steps per cycle
  • MLP-Mixer architecture (no self-attention)
  • Exponential Moving Average (EMA) for training stability
  • Adaptive Computation Time (ACT) for early stopping

Target Performance:

  • Test Accuracy: 87.4%
  • Parameters: 5M
  • Effective Depth: 42 layers (T × (n+1) × num_layers = 3 × 7 × 2)
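The effective depth comes from reusing one tiny network over and over rather than stacking layers. A minimal sketch of the recursion schedule (hypothetical names; how the inputs x, y, z are actually combined is defined in model.py):

import torch

def trm_recursion(net, x, y, z, T=3, n=6):
    # The same 2-layer network is applied at every step. Each cycle refines
    # the latent state z n times, then the answer y once, so the effective
    # depth is T * (n + 1) * 2 = 3 * 7 * 2 = 42 layers.
    for _ in range(T):
        for _ in range(n):
            z = net(x + y + z)  # update the latent reasoning state
        y = net(y + z)          # update the current answer once per cycle
    return y, z

# Toy usage with a stand-in 2-layer network (the real one lives in model.py):
net = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.Linear(512, 512))
x = y = z = torch.zeros(1, 81, 512)
y, z = trm_recursion(net, x, y, z)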

Prerequisites

  • Python 3.9-3.12 (3.12 recommended)
  • GPU recommended: CUDA-capable GPU or Apple Silicon with MPS
  • ~2GB GPU memory for training with batch size 768
  • Note: PyTorch 2.2.2 is the latest version with macOS x86_64 (Intel Mac) support
  • uv package manager

Installation

  1. Clone or download this repository
  2. Install uv if you haven't already:

curl -LsSf https://astral.sh/uv/install.sh | sh

  3. Install dependencies:

uv sync

Or install in your current environment:

uv pip install -e .

Data Setup

Place your Sudoku dataset files in the data/ directory:

trm/
├── data/
│   ├── train.txt    # Training puzzles (1K samples)
│   └── test.txt     # Test puzzles (423K samples)

Data Format:

  • Each puzzle should be a 9x9 grid flattened to 81 values
  • Values: 0 (empty cell) or 1-9 (filled cells)
  • Input: partially filled puzzles
  • Output: complete solutions
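The exact file layout of train.txt and test.txt is handled by dataset.py. As an illustration only, assuming each line stores an 81-character puzzle string followed by its 81-character solution, separated by a comma, a parser could look like this:

import numpy as np

def parse_line(line):
    # Assumed format: '<81 puzzle digits>,<81 solution digits>' per line,
    # with 0 (or '.') marking empty cells. The real loader is dataset.py.
    puzzle_str, solution_str = line.strip().split(",")
    puzzle = np.array([int(c) if c.isdigit() else 0 for c in puzzle_str], dtype=np.int64)
    solution = np.array([int(c) for c in solution_str], dtype=np.int64)
    assert puzzle.shape == (81,) and solution.shape == (81,)
    return puzzle.reshape(9, 9), solution.reshape(9, 9)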

Usage

Training

Run the training script:

uv run python train.py

Training Hyperparameters (hardcoded in train.py):

  • Learning rate: 1e-4
  • Weight decay: 1.0
  • Batch size: 768
  • Training epochs: 60,000
  • Data augmentation: 1000 shuffling augmentations per training sample (see the sketch below)
  • Max supervision steps: 16
  • EMA decay: 0.999
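As a rough illustration, the optimizer and EMA settings above could be wired up as follows. This is a sketch with a stand-in model and plain AdamW, not necessarily the exact optimizer used in train.py:

import copy
import torch

model = torch.nn.Linear(512, 512)  # stand-in; the real model is built in model.py

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1.0)

# Exponential Moving Average copy of the weights, decayed at 0.999 per step.
ema_model = copy.deepcopy(model)
for p in ema_model.parameters():
    p.requires_grad_(False)

@torch.no_grad()
def update_ema(model, ema_model, decay=0.999):
    for p, ema_p in zip(model.parameters(), ema_model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)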

Training will save checkpoints to the checkpoints/ directory.
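The shuffling augmentations listed above are Sudoku-validity-preserving transformations applied identically to puzzle and solution. A hedged sketch of one common scheme (digit relabeling plus row and band shuffles); the exact transformations in dataset.py may differ:

import numpy as np

def augment(puzzle, solution, rng):
    # puzzle, solution: (9, 9) integer arrays; rng: np.random.Generator.
    # Digit relabeling: 0 (empty) stays 0, digits 1-9 are permuted.
    relabel = np.concatenate(([0], rng.permutation(np.arange(1, 10))))
    puzzle, solution = relabel[puzzle], relabel[solution]

    # Shuffle rows within each 3-row band, then shuffle the bands themselves;
    # both operations preserve Sudoku validity.
    bands = rng.permutation(3)
    rows = np.concatenate([3 * b + rng.permutation(3) for b in bands])
    return puzzle[rows], solution[rows]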

Evaluation

Run the evaluation script:

uv run python eval.py \
  --checkpoint checkpoints/checkpoint_ema_epoch_100.pt \
  --data-dir data/sudoku-extreme-1k-aug-1000

Options:

  • --checkpoint: Path to checkpoint file (use EMA checkpoint for best results)
  • --data-dir: Path to dataset directory
  • --batch-size: Batch size for evaluation (default: 768)
  • --device: Device to use (default: cuda if available, else cpu)
  • --seed: Random seed for reproducibility (default: 42)

The evaluation script reports:

  • Per-puzzle accuracy: Percentage of puzzles solved correctly (all 81 cells)
  • Per-cell accuracy: Percentage of individual cells predicted correctly
  • Target accuracy from paper: 87.4%
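Both metrics are simple functions of the predicted and target grids; a minimal sketch, assuming integer tensors of shape (batch, 81):

import torch

def sudoku_accuracy(preds, targets):
    # preds, targets: integer tensors of shape (batch, 81).
    correct = preds == targets
    per_cell = correct.float().mean().item()                # fraction of cells correct
    per_puzzle = correct.all(dim=1).float().mean().item()   # all 81 cells correct
    return per_puzzle, per_cell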

Project Structure

trm/
├── model.py              # TRM architecture (embeddings, layers, network)
├── train.py              # Training script with deep supervision
├── eval.py               # Evaluation script
├── dataset.py            # Sudoku dataset loader and augmentation
├── losses.py             # Loss functions (prediction + halting)
├── pyproject.toml        # Project configuration and dependencies
├── README.md             # This file
├── IMPLEMENTATION_PLAN.md # Detailed implementation plan
├── 2510.04871v1.pdf      # Paper
├── data/                 # Data directory (gitignored)
└── checkpoints/          # Model checkpoints (gitignored)

Architecture Details

  • Hidden size: 512
  • Number of layers: 2
  • Context length: 81 (9×9 grid)
  • Vocabulary size: 10 (tokens 0-9)
  • Recursions per cycle: n=6
  • Deep recursion cycles: T=3
  • Activation: SwiGLU
  • Normalization: RMSNorm
  • Position encoding: Rotary embeddings
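For reference, SwiGLU and RMSNorm follow their standard definitions; a minimal sketch (the actual modules live in model.py and may differ in details such as the hidden expansion factor):

import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Scale by the reciprocal root-mean-square over the feature dimension.
        return self.weight * x * x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()

class SwiGLU(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden, bias=False)
        self.w_up = nn.Linear(dim, hidden, bias=False)
        self.w_down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        # SwiGLU(x) = W_down(SiLU(W_gate x) * W_up x)
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))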

Implementation Notes

This implementation strictly follows the paper specifications:

  • All hyperparameters are hardcoded as specified in the paper
  • No configuration files or additional features
  • Focus on TRM-MLP variant only (no self-attention)
  • Flat project structure for simplicity

References

For implementation details, see:

  • Section 2.1: Base architecture components
  • Section 4: TRM improvements over HRM (the Hierarchical Reasoning Model)
  • Algorithm 3 (Figure 3, Page 5): Training pseudocode
  • Table 1 (Page 5): Performance comparison
  • Page 11: Complete hyperparameters
  • Section 5 (Page 8): Dataset specifications
