LLaDA

📄 Paper: Large Language Diffusion Models | 💻 Code: github.com/ML-GSAI/LLaDA

Resources and examples for training (finetuning and pretraining) and evaluating LLaDA diffusion language models.


Files

# Pipeline modules relevant to LLaDA
dllm/pipelines/llada
├── __init__.py                     # Package initialization
├── models/
│   ├── configuration_lladamoe.py   # LLaDA-MoE model configuration
│   ├── configuration_llada.py      # LLaDA model configuration
│   ├── modeling_lladamoe.py        # LLaDA-MoE model architecture
│   └── modeling_llada.py           # LLaDA model architecture
├── eval.py                         # Evaluation module
├── sampler.py                      # Inference module
└── trainer.py                      # Training module (pretraining, SFT, and GRPO/RL)

# Example entry points for training / inference / evaluation
examples/llada
├── chat.py                         # Interactive inference example
├── eval.sh                         # Automatic evaluation example
├── grpo.py                         # GRPO/RL training entry point
├── sample.py                       # Inference example
├── pt.py                           # Pretraining example
├── README.md                       # Documentation (you are here)
└── sft.py                          # Supervised finetuning example

Training

Read Useful tips for training and (optional) Slurm setup before training.

MoE checkpoints: For models like LLaDA-MoE-7B-A1B-Base, set "model_type" to "lladamoe" in the checkpoint’s config.json:
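
{
  ...
  "model_type": "lladamoe",
  ...
}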

SFT

For example, to SFT LLaDA-8B-Base on the Alpaca dataset for instruction following using 8 GPUs, run:

accelerate launch \
    --config_file scripts/accelerate_configs/fsdp.yaml \
    examples/llada/sft.py \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --dataset_args "tatsu-lab/alpaca" \
    --max_length 1024 \
    --num_train_epochs 5 \
    --learning_rate 2e-5 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --output_dir ".models/LLaDA-8B-Base/alpaca"

If you are using Slurm and want to train across multiple nodes (for example, 2 nodes with 8 GPUs each, 16 GPUs total), run:

sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/llada/sft.py" \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --dataset_args "tatsu-lab/alpaca" \
    --max_length 1024 \
    --num_train_epochs 5 \
    --learning_rate 2e-5 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --output_dir ".models/LLaDA-8B-Base/alpaca"

Reproducing LLaDA-8B-Instruct with SFT

Although LLaDA-8B-Instruct was trained on proprietary data, we did our best to reproduce it by finetuning LLaDA-8B-Base with SFT on the allenai/tulu-3-sft-mixture dataset:

# Preprocess the SFT data (optional, but avoids redundant preprocessing in multi-node training)
python dllm/tools/preprocess_sft_dataset.py \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --sft_map_fn_path "dllm.utils.default_sft_map_fn" \
    --dataset_args "allenai/tulu-3-sft-mixture" \
    --output_dir ".data/sft/llada/tulu-3-sft-mixture" \
    --num_proc 64

# Train on 24*8=192 A100s with FSDP; takes about 8 hours
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/llada/sft.py" \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --dataset_args ".data/sft/llada/tulu-3-sft-mixture" \
    --load_preprocessed_data True \
    --max_length 1024 \
    --num_train_epochs 5 \
    --learning_rate 2e-5 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --output_dir ".models/LLaDA-8B-Base/tulu-3-sft-mixture"

RL (GRPO)

We adapt GRPO (Group Relative Policy Optimization) for masked diffusion language models via DiffuGRPOTrainer, which replaces autoregressive generation with iterative denoising. The implementation follows the d1/diffu-grpo reference.

For example, to run GRPO on LLaDA-8B-Instruct with gsm8k (supported datasets: gsm8k, countdown, sudoku, math, code) on 1 GPU, run:

accelerate launch \
    --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
    examples/llada/grpo.py \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" \
    --dataset gsm8k \
    --num_train_epochs 1 \
    --output_dir ".models/LLaDA-8B-Instruct/gsm8k-grpo"

To train with LoRA on 8 GPUs using DeepSpeed ZeRO-2:

accelerate launch \
    --config_file scripts/accelerate_configs/zero2.yaml \
    examples/llada/grpo.py \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" \
    --lora_r 128 --lora_alpha 64 \
    --dataset gsm8k \
    --num_train_epochs 10 --learning_rate 3e-6 \
    --num_generations 6 --per_device_train_batch_size 6 \
    --beta 0.04 --epsilon 0.5 \
    --output_dir ".models/LLaDA-8B-Instruct/gsm8k-grpo"

Key diffusion-specific arguments: --block_size, --steps, --remasking, --p_mask_prompt. Key GRPO arguments: --beta, --epsilon, --num_generations, --num_iterations.
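
For illustration only, a hypothetical single-GPU run combining these options might look like the following (the values are placeholders rather than recommended defaults, and the low_confidence remasking strategy is an assumption borrowed from LLaDA's sampler):

# Illustrative only: placeholder values for the diffusion and GRPO knobs
accelerate launch \
    --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
    examples/llada/grpo.py \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" \
    --dataset gsm8k \
    --block_size 64 --steps 128 --remasking low_confidence --p_mask_prompt 0.3 \
    --beta 0.04 --epsilon 0.5 --num_generations 6 --num_iterations 2 \
    --output_dir ".models/LLaDA-8B-Instruct/gsm8k-grpo"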

Pretraining

Pretrain from scratch on mlfoundations/dclm-baseline-1.0 using 192 GPUs (24 nodes x 8 GPUs) and FSDP:

sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/llada/pt.py" \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --dataset_args "mlfoundations/dclm-baseline-1.0" \
    --max_length 1024 \
    --max_steps 2000 \
    --learning_rate 1e-4 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --output_dir ".models/LLaDA-8B-Base/dclm-baseline-1.0"

Inference

We support batch inference for standard sampling and infilling:

python examples/llada/sample.py --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct"

We also support interactive multi-turn dialogue with visualization:

python examples/llada/chat.py --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct"

Evaluation

Read (optional) Evaluation setup before running evaluation.

For example, to evaluate LLaDA-8B-Instruct on gsm8k using 4 GPUs, run:

# Use model_args to adjust the sampling arguments for evaluation.
accelerate launch --num_processes 4 \
    dllm/pipelines/llada/eval.py \
    --tasks "gsm8k_cot" \
    --model "llada" \
    --apply_chat_template \
    --num_fewshot 5 \
    --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,max_new_tokens=512,steps=512,block_size=512,cfg_scale=0.0,suppress_tokens=[],begin_suppress_tokens=[126081;126348]"

To automatically evaluate LLaDA-8B-Base and LLaDA-8B-Instruct on all benchmarks, run:

bash examples/llada/eval.sh --model_name_or_path GSAI-ML/LLaDA-8B-Instruct --instruct True
bash examples/llada/eval.sh --model_name_or_path GSAI-ML/LLaDA-8B-Base --instruct False

For Fast-dLLM sampling and evaluation with LLaDA, see the Fast-dLLM README.

Evaluation results

Reproduced results are evaluated with our framework, while Official results come from the original paper. All evaluation settings follow the configurations in the LLaDA repository, with minor adjustments.

| Model | MMLU | BBH | ARC‑C | Hellaswag | TruthfulQA | WinoGrande | PIQA | GSM8K | Math | GPQA | HumanEval | MBPP | CEval | CMMLU |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| LLaDA-8B-Base (Official) | 65.9 | 49.7 | 45.9 | 70.5 | 46.1 | 74.8 | 73.6 | 70.3 | 31.4 | 25.2 | 35.4 | 40.0 | 70.5 | 69.9 |
| LLaDA-8B-Base (Reproduced) | 65.9 | 47.2 | 44.1 | 69.2 | 45.6 | 70.4 | 70.7 | 70.7 | 32.4 | 31.9 | 32.9 | 38.8 | 70.4 | 69.8 |

Table 1. Evaluation results of LLaDA-8B-Base.

| Model | MMLU | MMLU‑Pro | ARC‑C | Hellaswag | GSM8K | Math | GPQA | HumanEval | MBPP |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| LLaDA-8B-Instruct (Official) | 65.5 | 37.0 | 88.5 | 74.6 | 69.4 | 31.9 | 33.3 | 49.4 | 41.0 |
| LLaDA-8B-Instruct (Reproduced) | 69.8 | 36.2 | 86.4 | 76.7 | 74.7 | 31.9 | 30.6 | 47.0 | 40.0 |

Table 2. Evaluation results of LLaDA-8B-Instruct.