Skip to content

taivu1998/PRM-Math

Folders and files

NameName
Last commit message
Last commit date

Latest commit

Β 

History

4 Commits
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 

Repository files navigation

PRM-Math: Inference-Time Compute Scaling via Dense Process Supervision

A production-ready implementation of Process Reward Models (PRMs) for mathematical reasoning, featuring state-of-the-art test-time compute scaling techniques. This project enables fine-tuning generative verifiers and deploying advanced inference strategies including Best-of-N Search, PRM-Weighted Voting, and Monte Carlo Tree Search (MCTS) with learned value functions.


Highlights

Benchmark Best Method Accuracy Improvement
GSM8K Majority@16 87.0% +5% over Pass@1
MATH-500 MCTS@20 54.0% +10% over all baselines

Key Result: On competition-level MATH-500 problems, MCTS with PRM value function achieves 54% accuracyβ€”a 10 percentage point improvement over majority voting and PRM reranking. This demonstrates the power of test-time compute scaling on challenging problems.


Key Features

  • Generative PRM Training: Fine-tune decoder-only models as step-level verifiers using the "next-token prediction" paradigm
  • Multiple Inference Strategies: Best-of-N, Majority Voting, PRM Reranking, PRM-Weighted Majority, and MCTS
  • MCTS with Learned Value Function: Tree search using generation logprobs as priors and PRM scores as value estimates
  • Efficient Training: 4-bit QLoRA training via Unsloth with 2x speedup and 60% memory reduction
  • Production Inference: vLLM integration for high-throughput batch scoring
  • Comprehensive Evaluation: Support for GSM8K, MATH, and MATH-500 benchmarks

Why Process Reward Models?

Traditional Outcome Reward Models (ORMs) only judge final answers, missing critical reasoning errors. Process Reward Models provide dense, step-level supervision that:

  1. Catches errors early - Identifies the first incorrect step, not just wrong answers
  2. Enables better search - Provides reliable value estimates for tree search algorithms
  3. Improves sample efficiency - More training signal per problem than sparse outcome rewards
  4. Scales at test-time - More compute at inference = better accuracy (test-time scaling)
ORM: Problem β†’ Solution β†’ Score (0 or 1)
PRM: Problem β†’ Step₁ β†’ Score₁ β†’ Stepβ‚‚ β†’ Scoreβ‚‚ β†’ ... β†’ StepN β†’ ScoreN

Test-Time Compute Scaling

This project implements the key insight from recent research: scaling inference compute can be more efficient than scaling model parameters. We provide three levels of test-time scaling:

Strategy Compute Description
Best-of-N O(N) Generate N solutions, select highest PRM score
Weighted Voting O(N) Combine majority voting with PRM confidence weighting
MCTS O(simulations) Tree search with PRM value function and logprob priors

Architecture

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                    PRM-Math Architecture                         β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                  β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”       β”‚
β”‚  β”‚  Base Model  β”‚    β”‚   Fine-tune  β”‚    β”‚  PRM Model   β”‚       β”‚
β”‚  β”‚  (Qwen-1.5B) β”‚ ─► β”‚   (QLoRA)    β”‚ ─► β”‚  (Verifier)  β”‚       β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜       β”‚
β”‚         β”‚                                        β”‚               β”‚
β”‚         β–Ό                                        β–Ό               β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”   β”‚
β”‚  β”‚                  Inference Strategies                     β”‚   β”‚
β”‚  β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€   β”‚
β”‚  β”‚  Best-of-N   β”‚   Weighted   β”‚    MCTS      β”‚   Beam      β”‚   β”‚
β”‚  β”‚   Search     β”‚    Voting    β”‚   Search     β”‚   Search    β”‚   β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜   β”‚
β”‚                              β”‚                                   β”‚
β”‚                              β–Ό                                   β”‚
β”‚                    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                         β”‚
β”‚                    β”‚   Final Answer   β”‚                         β”‚
β”‚                    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                         β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Technical Details

Component Implementation
Base Model Qwen/Qwen2.5-Math-1.5B-Instruct
Training QLoRA (4-bit) via Unsloth + TRL SFTTrainer
Verification Format Context + Step + <|verify|> β†’ +/-
Score Aggregation Min (Weakest Link) or Product (Cumulative)
MCTS Prior Generation log-probabilities
MCTS Value PRM step-product scores
Dataset Math-Shepherd (step-labeled reasoning traces)

Installation

Requirements

  • Python >= 3.10
  • CUDA-capable GPU (16GB+ VRAM for training, 8GB+ for inference)
  • Linux or WSL2 (Unsloth requirement)

Quick Start

# Clone the repository
git clone https://github.com/yourusername/PRM-Math.git
cd PRM-Math

# Automated setup (recommended)
./setup.sh
source venv/bin/activate

# Or manual installation
python -m venv venv && source venv/bin/activate
pip install torch --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt
pip install "unsloth[cu121-ampere-torch220] @ git+https://github.com/unslothai/unsloth.git"
pip install -e .

Google Colab

# Install dependencies (run in first cell, then restart runtime)
!pip install torch --index-url https://download.pytorch.org/whl/cu121
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install transformers datasets accelerate trl peft bitsandbytes vllm

Usage

1. Training a Process Reward Model

Fine-tune a generative verifier on the Math-Shepherd dataset:

# Train with default configuration
make train

# Or with custom parameters
python scripts/train.py --config configs/default.yaml \
    --training.learning_rate 2e-4 \
    --training.num_train_epochs 1 \
    --data.max_samples 50000

Training outputs:

  • ./checkpoints/ - LoRA adapter checkpoints
  • ./checkpoints/merged_model/ - Merged 16-bit model for inference
  • ./logs/ - Training logs and metrics

2. Inference with Best-of-N Search

# Interactive demo
python scripts/inference.py --model_path checkpoints/merged_model

# Single problem with custom settings
python scripts/inference.py \
    --model_path checkpoints/merged_model \
    --problem "A store sells apples for $2 each. If you buy 5 apples and pay with a $20 bill, how much change do you get?" \
    --n_candidates 16 \
    --temperature 0.8 \
    --aggregation product  # or "min" for weakest-link

3. MCTS Inference (Advanced)

Monte Carlo Tree Search with PRM as the value function:

# Single problem with MCTS
python scripts/mcts.py \
    --model_path checkpoints/merged_model \
    --problem "Find all integer solutions to x^2 - 5x + 6 = 0" \
    --simulations 50

# Evaluate MCTS on GSM8K
python scripts/mcts.py \
    --model_path checkpoints/merged_model \
    --dataset gsm8k \
    --n_problems 100 \
    --simulations 20

4. Comprehensive Evaluation

Evaluate on standard benchmarks with multiple strategies:

# GSM8K evaluation (grade school math)
python scripts/evaluate.py \
    --model_path checkpoints/merged_model \
    --dataset gsm8k \
    --n_candidates 16 \
    --n_problems 500

# MATH-500 evaluation (competition-level)
python scripts/evaluate.py \
    --model_path checkpoints/merged_model \
    --dataset math500 \
    --n_candidates 8 \
    --aggregation product

# Full benchmark suite
make evaluate CHECKPOINT=checkpoints/merged_model DATASET=gsm8k N_CANDIDATES=16
make evaluate CHECKPOINT=checkpoints/merged_model DATASET=math500 N_CANDIDATES=8

Evaluation Strategies

The evaluation script compares four selection strategies:

Strategy Description
Pass@1 Baseline: first generated solution
Majority@N Most common answer among N solutions
PRM Rerank@N Highest PRM-scored solution
PRM-Weighted@N Majority voting weighted by PRM scores

Example output:

============================================================
EVALUATION RESULTS
============================================================
Dataset: gsm8k
N problems: 500
N candidates: 16
Aggregation: product
------------------------------------------------------------
Pass@1:           52.40%
Majority@16:      61.20%
PRM Rerank@16:    67.80%
PRM-Weighted@16:  68.40%
============================================================

PRM Rerank improvement over Pass@1: +15.40%
PRM-Weighted improvement over Pass@1: +16.00%

Project Structure

PRM-Math/
β”œβ”€β”€ configs/
β”‚   └── default.yaml          # Training hyperparameters
β”œβ”€β”€ scripts/
β”‚   β”œβ”€β”€ train.py              # PRM fine-tuning (Unsloth + TRL)
β”‚   β”œβ”€β”€ inference.py          # Best-of-N search with PRM
β”‚   β”œβ”€β”€ evaluate.py           # Benchmark evaluation
β”‚   └── mcts.py               # MCTS with PRM value function
β”œβ”€β”€ src/
β”‚   β”œβ”€β”€ __init__.py
β”‚   β”œβ”€β”€ config_parser.py      # YAML + CLI configuration
β”‚   β”œβ”€β”€ dataset.py            # Math-Shepherd data processing
β”‚   β”œβ”€β”€ model.py              # Model loading and merging
β”‚   └── utils.py              # Utilities (logging, seeding)
β”œβ”€β”€ notebooks/
β”‚   └── PRM_Math_Colab.ipynb  # Interactive Colab notebook
β”œβ”€β”€ checkpoints/              # Saved models
β”œβ”€β”€ logs/                     # Training logs
β”œβ”€β”€ eval_results/             # Evaluation outputs
β”œβ”€β”€ Makefile                  # Command shortcuts
β”œβ”€β”€ setup.sh                  # Setup script
β”œβ”€β”€ requirements.txt          # Dependencies
└── pyproject.toml            # Package config

How It Works

The Generative Verifier Paradigm

Instead of training a separate classification head, we formulate verification as conditional generation:

Input:  "Problem: What is 2+2?\n\nSolution:\nStep 1: 2+2 = 4\n<|verify|>"
Target: "+"  (or "-" for incorrect steps)

Advantages:

  • Leverages pretrained language modeling capabilities
  • No architectural modifications needed
  • Compatible with standard fine-tuning pipelines
  • Enables probability-based confidence scores

Step-wise Scoring Pipeline

def score_solution(problem, solution):
    steps = solution.split("\n")
    context = f"Problem: {problem}\n\nSolution:"
    scores = []

    for step in steps:
        # Get P("+") after <|verify|> token
        logits = model(context + step + "<|verify|>")
        p_correct = softmax(logits)["+"] / (softmax(logits)["+"] + softmax(logits)["-"])
        scores.append(p_correct)
        context += f"\n{step}"

    # Aggregate: product (cumulative) or min (weakest link)
    return aggregate(scores)

MCTS with PRM Value Function

Our MCTS implementation uses:

  • Prior Policy: Log-probabilities from the base model's generation
  • Value Function: PRM scores (product aggregation)
  • UCB Selection: Balances exploitation (high value) and exploration (low visits)
                    Root (Problem)
                    /     |     \
               Step1a  Step1b  Step1c   ← Expand with logprob priors
                 |        |        |
              [0.92]   [0.87]   [0.65]  ← Evaluate with PRM
                 |
              Step2a  Step2b            ← Continue search
                ...

Configuration

Training Configuration (configs/default.yaml)

project:
  name: "prm-math-qwen"
  seed: 42
  output_dir: "./checkpoints"
  logging_dir: "./logs"

data:
  dataset_name: "peiyi9979/Math-Shepherd"
  max_samples: 50000
  balance_positives: true   # Balance +/- labels

model:
  base_model: "Qwen/Qwen2.5-Math-1.5B-Instruct"
  max_seq_length: 2048
  load_in_4bit: true

lora:
  r: 16
  lora_alpha: 16
  lora_dropout: 0.0
  target_modules:
    - q_proj
    - k_proj
    - v_proj
    - o_proj
    - gate_proj
    - up_proj
    - down_proj

training:
  batch_size: 8
  gradient_accumulation_steps: 4
  learning_rate: 2.0e-4
  num_train_epochs: 1
  warmup_ratio: 0.03
  response_template: "<|verify|>"
  save_steps: 100
  logging_steps: 10

inference:
  n_candidates: 16
  temperature: 0.8
  aggregation: "product"  # or "min"

Experimental Results

We evaluated our PRM on two benchmarks: GSM8K (grade school math) and MATH-500 (competition-level problems). The results demonstrate the power of test-time scaling, especially on harder problems.

GSM8K Results (100 problems, 16 candidates)

Method Accuracy vs Pass@1
Pass@1 (baseline) 82.0% -
Majority@16 87.0% +5.0%
PRM Rerank@16 81.0% -1.0%
PRM-Weighted@16 83.0% +1.0%
MCTS@1 78.0% -4.0%
MCTS@5 85.0% +3.0%
MCTS@10 86.0% +4.0%
MCTS@20 83.0% +1.0%
MCTS@50 83.0% +1.0%

Key Finding: On GSM8K, the base model is already strong (82% Pass@1). Majority voting provides the best improvement (+5%), while MCTS@10 achieves 86% accuracy with +5% improvement over PRM Rerank.

MATH-500 Results (100 problems, 16 candidates)

Method Accuracy vs Pass@1
Pass@1 (baseline) 40.0% -
Majority@16 44.0% +4.0%
PRM Rerank@16 44.0% +4.0%
PRM-Weighted@16 44.0% +4.0%
MCTS@1 36.0% -4.0%
MCTS@5 48.0% +8.0%
MCTS@10 49.0% +9.0%
MCTS@20 54.0% +14.0%
MCTS@50 51.0% +11.0%

Key Finding: On competition-level MATH-500 problems, MCTS dramatically outperforms all other methods. MCTS@20 achieves 54% accuracyβ€”a massive +10% improvement over Majority/PRM Rerank, demonstrating that test-time scaling is most beneficial on harder problems where the model needs guided exploration.

Why MCTS Excels on Hard Problems

                    GSM8K (Easy)              MATH-500 (Hard)
                    ────────────              ───────────────
Pass@1              β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ 82%          β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ 40%
Majority@16         β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ 87%         β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ 44%
PRM Rerank@16       β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ 81%          β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ 44%
MCTS@20             β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ 83%         β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ 54%  ← +10% gain!

The pattern is clear: harder problems benefit more from intelligent search. When the base model struggles (40% on MATH-500), MCTS's ability to explore multiple reasoning paths and backtrack from errors becomes crucial.

Scaling Behavior

MCTS shows optimal performance at moderate simulation counts:

MATH-500 Accuracy
     β”‚
 54% β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β—β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ MCTS@20 (Best)
     β”‚              ●     ●
 50% β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β—β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β—β”€β”€β”€β”€β”€β”€β”€β”€β”€ MCTS@10, MCTS@50
     β”‚        ●
 48% β”œβ”€β”€β”€β”€β”€β”€β—β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ MCTS@5
     β”‚
 44% β”œβ”€β”€β”€β”€β—β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ Majority/PRM@16
     β”‚
 40% β”œβ”€β”€β—β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ Pass@1
     │●
 36% β”œβ—β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ MCTS@1
     β”‚
     └───┬────┬────┬────┬────┬────┬──►
         1    5   10   20   50  100
              MCTS Simulations

Insight: More simulations help up to a point (MCTS@20), after which returns diminish. This suggests an optimal compute budget exists for each problem difficulty level.

Key Takeaways

  1. Test-time scaling works: Spending more compute at inference consistently improves accuracy
  2. Method selection matters: Simple majority voting wins on easy problems; MCTS excels on hard problems
  3. MCTS sweet spot: 10-20 simulations provides the best accuracy/compute tradeoff
  4. PRM enables search: Without step-level scores, tree search wouldn't know which paths to explore
  5. Harder problems = bigger gains: The gap between Pass@1 and MCTS grows with problem difficulty

Troubleshooting

CUDA Out of Memory

# Reduce batch size
training:
  batch_size: 4
  gradient_accumulation_steps: 8

Unsloth Installation Issues

# For older GPUs (compute capability < 8.0)
pip install "unsloth[cu121-torch220] @ git+https://github.com/unslothai/unsloth.git"

vLLM Compatibility

# If vLLM fails, use transformers backend
python scripts/inference.py --model_path checkpoints/merged_model  # No --use_vllm flag

Dataset Access

# Ensure HuggingFace access
export HF_DATASETS_OFFLINE=0
huggingface-cli login  # If using gated datasets

References

Papers

Libraries

  • Unsloth - Efficient fine-tuning
  • TRL - Transformer Reinforcement Learning
  • vLLM - High-throughput inference

Citation

If you use this code in your research, please cite:

@software{prm_math,
  title = {PRM-Math: Advanced Test-Time Scaling with Process Reward Models},
  year = {2025},
  url = {https://github.com/yourusername/PRM-Math}
}

License

MIT License - see LICENSE for details.


Note: This is a research project. Results may vary based on hardware, random seeds, and hyperparameters. For production use, consider additional validation and testing.

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors