Measure KL divergence between MLX language model output distributions. Compare a reference model (typically full-precision) against one or more quantized variants to see how much quantization shifts the model's probability distributions.
```bash
python -m venv .venv
source .venv/bin/activate
pip install -e .
```

```bash
mlx-kld \
  --reference /path/to/full-precision-model \
  --compare /path/to/quantized-model \
  --prompts "What is the capital of France?" "Explain photosynthesis." \
  --top-k 10
```

| Flag | Description |
|---|---|
| `--reference` | Path or HF repo for the reference model |
| `--compare` | One or more paths/HF repos for comparison models |
| `--prompts` | One or more prompt strings |
| `--prompts-file` | Text file with one prompt per line |
| `--output` | Path prefix for JSON results (see below) |
| `--save-reference PATH` | Save reference logits to a `.npz` file after running |
| `--load-reference PATH` | Load previously saved reference logits; skips running the reference model |
| `--top-k N` | Show the N most divergent tokens per model |
| `--no-chat-template` | Tokenize raw prompts instead of applying the model's chat template |
Compare two quants against the same reference in one run:
```bash
mlx-kld \
  --reference /path/to/Qwen3.5-27B-FULL \
  --compare /path/to/Qwen3.5-27B-oQ5 /path/to/Qwen3.5-27B-5bit \
  --prompts-file prompts.txt \
  --output results
```

This produces `results_Qwen3.5-27B-oQ5.json` and `results_Qwen3.5-27B-5bit.json`, plus a side-by-side summary table in the terminal.
Save reference logits once, then run additional comparisons without reloading the reference model:
```bash
# First run — saves reference logits to disk
mlx-kld \
  --reference /path/to/Qwen3.5-27B-FULL \
  --compare /path/to/Qwen3.5-27B-oQ5 \
  --prompts-file prompts.txt \
  --save-reference qwen27b_ref \
  --output results

# Later — test a new quant, reference model never loaded
mlx-kld \
  --load-reference qwen27b_ref \
  --compare /path/to/Qwen3.5-27B-oQ4 \
  --output results
```

The `.npz` cache stores the full log-prob distributions for every token of every prompt. Prompts are baked into the cache, so `--prompts` / `--prompts-file` are ignored when using `--load-reference`.
With a single `--compare` model, `--output results` saves to `results.json`. With multiple models, it uses the last component of each model path as a slug: `results_<modelname>.json`.
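The naming scheme described above can be sketched as follows (a hypothetical helper for illustration, not the tool's actual code):

```python
import os

def output_paths(prefix, model_paths):
    """Single model -> prefix.json; multiple models -> prefix_<modelname>.json,
    using the last path component as the slug."""
    if len(model_paths) == 1:
        return [prefix + ".json"]
    return [f"{prefix}_{os.path.basename(p.rstrip('/'))}.json" for p in model_paths]

print(output_paths("results", ["/path/to/Qwen3.5-27B-oQ5"]))
# ['results.json']
print(output_paths("results", ["/path/to/Qwen3.5-27B-oQ5", "/path/to/Qwen3.5-27B-5bit"]))
# ['results_Qwen3.5-27B-oQ5.json', 'results_Qwen3.5-27B-5bit.json']
```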
- Loads the reference model, runs a forward pass on each prompt, and stores the log-softmax output distributions as NumPy arrays
- Unloads the reference model to free memory (or skips this step entirely with `--load-reference`)
- For each comparison model: loads it, runs the same forward passes, computes per-token KL divergence against the stored reference distributions, then unloads it before loading the next
- Reports per-model statistics, plus a side-by-side summary table when comparing multiple models
Only one model is in memory at a time, which matters when the reference is a large full-precision model.
KL divergence is computed at each token position in the prefill (the forward pass over the prompt tokens):

```
KL(P_ref || P_cmp) = sum(P_ref * (log(P_ref) - log(P_cmp)))
```
This tells you how much information is lost at each token position by using the comparison model instead of the reference.
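As a sanity check, the formula can be evaluated directly with NumPy on toy log-softmax distributions (a sketch of the math, not the tool's implementation):

```python
import numpy as np

def kl_divergence(log_p_ref, log_p_cmp):
    # KL(P_ref || P_cmp) = sum(P_ref * (log P_ref - log P_cmp))
    p_ref = np.exp(log_p_ref)
    return np.sum(p_ref * (log_p_ref - log_p_cmp), axis=-1)

# Toy logits over a 4-token vocabulary, converted to log-softmax.
logits_ref = np.array([2.0, 1.0, 0.5, 0.1])
logits_cmp = np.array([2.1, 0.9, 0.5, 0.2])
log_p_ref = logits_ref - np.log(np.sum(np.exp(logits_ref)))
log_p_cmp = logits_cmp - np.log(np.sum(np.exp(logits_cmp)))

kld = kl_divergence(log_p_ref, log_p_cmp)
print(float(kld))  # small and non-negative; exactly 0 when the distributions match
```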
```
==================================================
KL Divergence Results
==================================================
Reference: /path/to/Qwen3.5-27B-FULL
Compare:   /path/to/Qwen3.5-27B-oQ5
==================================================
Prompts: 35
Tokens:  1470
──────────────────────────────────────────────────
Mean KLD:   0.009301
Median KLD: 0.002092
Std KLD:    0.066812
P95 KLD:    0.022045
P99 KLD:    0.078192
Max KLD:    2.033358
  (prompt 16, position 70, token ",")
==================================================
```
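These summary statistics can be reproduced from the raw per-token KLD values with NumPy. A sketch with toy values (not the run shown above):

```python
import numpy as np

# Toy per-token KLD values standing in for a real run's output.
klds = np.array([0.001, 0.002, 0.003, 0.002, 0.5, 0.004, 0.002, 0.001])

print(f"Mean KLD:   {np.mean(klds):.6f}")
print(f"Median KLD: {np.median(klds):.6f}")
print(f"P95 KLD:    {np.percentile(klds, 95):.6f}")
print(f"P99 KLD:    {np.percentile(klds, 99):.6f}")
print(f"Max KLD:    {np.max(klds):.6f}")
```

Note how a single outlier token dominates the max while barely moving the median, which is why the tool reports both.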
When multiple models are compared, a summary table is printed after the individual results:
```
==================================================
Summary Comparison
==================================================
Model       Qwen3.5-27B-oQ5    Qwen3.5-27B-5bit
──────────────────────────────────────────────────────────────────
Mean KLD    0.009301 *         0.010491
Median KLD  0.002092 *         0.002274
Std KLD     0.066812 *         0.114512
P95 KLD     0.022045 *         0.022159
P99 KLD     0.078192           0.065190 *
Max KLD     2.033358 *         4.066809

* = best (lowest) for that metric
==================================================
```
The `--output` JSON includes per-token KLD values, token IDs, and decoded token strings for every prompt.
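The per-token detail makes it possible to locate exactly where a quant diverges. A sketch of post-processing such a file (the field names here are hypothetical, not the tool's actual schema):

```python
import json

# Hypothetical results file contents; the real schema may differ.
results = {
    "prompts": [
        {"text": "What is the capital of France?",
         "token_ids": [3923, 374, 279],
         "tokens": ["What", " is", " the"],
         "kld": [0.0011, 0.0204, 0.0009]},
    ]
}
with open("results.json", "w") as f:
    json.dump(results, f)

# Find the single most divergent token across all prompts.
with open("results.json") as f:
    data = json.load(f)
worst = max(
    ((p["text"], i, p["tokens"][i], k)
     for p in data["prompts"] for i, k in enumerate(p["kld"])),
    key=lambda t: t[3],
)
print(worst)  # (prompt text, position, token string, KLD value)
```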