# mlx-kld

Measure KL divergence between MLX language model output distributions. Compare a reference model (typically full-precision) against one or more quantized variants to see how much quantization shifts the model's probability distributions.

## Install

```bash
python -m venv .venv
source .venv/bin/activate
pip install -e .
```

## Usage

```bash
mlx-kld \
  --reference /path/to/full-precision-model \
  --compare /path/to/quantized-model \
  --prompts "What is the capital of France?" "Explain photosynthesis." \
  --top-k 10
```

## Options

| Flag | Description |
| --- | --- |
| `--reference` | Path or HF repo for the reference model |
| `--compare` | One or more paths/HF repos for comparison models |
| `--prompts` | One or more prompt strings |
| `--prompts-file` | Text file with one prompt per line |
| `--output` | Path prefix for JSON results (see below) |
| `--save-reference PATH` | Save reference logits to a `.npz` file after running |
| `--load-reference PATH` | Load previously saved reference logits; skips running the reference model |
| `--top-k N` | Show the N most divergent tokens per model |
| `--no-chat-template` | Tokenize raw prompts instead of applying the model's chat template |

## Examples

Compare two quants against the same reference in one run:

```bash
mlx-kld \
  --reference /path/to/Qwen3.5-27B-FULL \
  --compare /path/to/Qwen3.5-27B-oQ5 /path/to/Qwen3.5-27B-5bit \
  --prompts-file prompts.txt \
  --output results
```

This produces `results_Qwen3.5-27B-oQ5.json` and `results_Qwen3.5-27B-5bit.json`, plus a side-by-side summary table in the terminal.

Save reference logits once, then run additional comparisons without reloading the reference model:

```bash
# First run — saves reference logits to disk
mlx-kld \
  --reference /path/to/Qwen3.5-27B-FULL \
  --compare /path/to/Qwen3.5-27B-oQ5 \
  --prompts-file prompts.txt \
  --save-reference qwen27b_ref \
  --output results

# Later — test a new quant, reference model never loaded
mlx-kld \
  --load-reference qwen27b_ref \
  --compare /path/to/Qwen3.5-27B-oQ4 \
  --output results
```

The `.npz` cache stores the full log-prob distributions for every token of every prompt. Prompts are baked into the cache, so `--prompts` / `--prompts-file` are ignored when using `--load-reference`.
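As a rough sketch of what such a cache could hold (the key names here are illustrative assumptions, not mlx-kld's actual on-disk format): one log-prob array per prompt, plus the prompt strings themselves so a later run can reuse them verbatim.

```python
import os
import tempfile

import numpy as np

# Hypothetical cache layout: prompt strings plus one (tokens, vocab)
# log-prob array per prompt. Key names are assumptions for illustration.
prompts = ["What is the capital of France?", "Explain photosynthesis."]
rng = np.random.default_rng(0)
logprobs = [rng.standard_normal((12, 100)), rng.standard_normal((9, 100))]

path = os.path.join(tempfile.mkdtemp(), "ref_cache.npz")
np.savez(
    path,
    prompts=np.array(prompts, dtype=object),
    **{f"logprobs_{i}": lp for i, lp in enumerate(logprobs)},
)

# Reloading: allow_pickle is required for the object-dtype prompt array.
cache = np.load(path, allow_pickle=True)
print(sorted(cache.files))  # ['logprobs_0', 'logprobs_1', 'prompts']
```

Keeping the prompts inside the archive is what makes the cache self-describing: a later run can tokenize and forward-pass a comparison model on exactly the same inputs without any prompt flags.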

With a single `--compare` model, `--output results` saves to `results.json`. With multiple models, it uses the last component of each model path as a slug: `results_<modelname>.json`.
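The naming rule is simple enough to sketch. This is a guess at the logic rather than the tool's actual code, using the final path component as the slug:

```python
from pathlib import Path

def output_paths(prefix: str, model_paths: list[str]) -> list[str]:
    # One comparison model: write straight to <prefix>.json.
    if len(model_paths) == 1:
        return [f"{prefix}.json"]
    # Several models: suffix each with the model directory's last component.
    return [f"{prefix}_{Path(p).name}.json" for p in model_paths]

print(output_paths("results", ["/path/to/Qwen3.5-27B-oQ5"]))
# → ['results.json']
print(output_paths("results", ["/path/to/Qwen3.5-27B-oQ5", "/path/to/Qwen3.5-27B-5bit"]))
# → ['results_Qwen3.5-27B-oQ5.json', 'results_Qwen3.5-27B-5bit.json']
```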

## How it works

1. Loads the reference model, runs a forward pass on each prompt, and stores the log-softmax output distributions as numpy arrays.
2. Unloads the reference model to free memory (or skips this step entirely with `--load-reference`).
3. For each comparison model: loads it, runs the same forward passes, computes per-token KL divergence against the stored reference distributions, then unloads it before loading the next.
4. Reports per-model statistics, plus a side-by-side summary table when comparing multiple models.

Only one model is in memory at a time, which matters when the reference is a large full-precision model.

## What's measured

KL divergence at each token position in the prefill (the forward pass over the prompt tokens):

```
KL(P_ref || P_cmp) = sum(P_ref * (log(P_ref) - log(P_cmp)))
```

This tells you how much information is lost at each token position by using the comparison model instead of the reference.
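A minimal numpy sketch of this computation, assuming both models' log-softmax outputs are available as arrays of shape `(num_tokens, vocab_size)`:

```python
import numpy as np

def kl_divergence(ref_logprobs: np.ndarray, cmp_logprobs: np.ndarray) -> np.ndarray:
    """Per-position KL(P_ref || P_cmp), given log-softmax outputs.

    Both arrays have shape (num_tokens, vocab_size) and hold
    log-probabilities.
    """
    p_ref = np.exp(ref_logprobs)
    # Sum over the vocabulary axis at each token position.
    return np.sum(p_ref * (ref_logprobs - cmp_logprobs), axis=-1)

# Tiny worked example: one position over a 3-token vocabulary.
ref = np.log(np.array([[0.7, 0.2, 0.1]]))
cmp_lp = np.log(np.array([[0.6, 0.3, 0.1]]))
print(kl_divergence(ref, cmp_lp))  # small positive value (~0.027 nats)
print(kl_divergence(ref, ref))     # exactly 0 when the distributions match
```

Working from log-probabilities (rather than exponentiating both sides first) keeps the `log(P_ref) - log(P_cmp)` term numerically stable for the long tail of near-zero probabilities.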

## Output

```
==================================================
  KL Divergence Results
==================================================
  Reference: /path/to/Qwen3.5-27B-FULL
  Compare:   /path/to/Qwen3.5-27B-oQ5
==================================================
  Prompts:     35
  Tokens:      1470
──────────────────────────────────────────────────
  Mean KLD:    0.009301
  Median KLD:  0.002092
  Std KLD:     0.066812
  P95 KLD:     0.022045
  P99 KLD:     0.078192
  Max KLD:     2.033358
             (prompt 16, position 70, token ",")
==================================================
```

When multiple models are compared, a summary table is printed after the individual results:

```
==================================================
  Summary Comparison
==================================================
Model                         Qwen3.5-27B-oQ5  Qwen3.5-27B-5bit
──────────────────────────────────────────────────────────────────
Mean KLD                           0.009301 *      0.010491
Median KLD                         0.002092 *      0.002274
Std KLD                            0.066812 *      0.114512
P95 KLD                            0.022045 *      0.022159
P99 KLD                            0.078192        0.065190 *
Max KLD                            2.033358 *      4.066809

  * = best (lowest) for that metric
==================================================
```

The `--output` JSON includes per-token KLD values, token IDs, and decoded token strings for every prompt.
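Because the per-token values are in the JSON, the summary statistics can be recomputed (or extended) offline. The field names below are assumptions for illustration, not the tool's documented schema:

```python
import numpy as np

# Hypothetical result structure: per-prompt lists of token IDs, decoded
# token strings, and per-token KLD values (field names are assumptions).
results = {
    "prompts": [
        {"token_ids": [1212, 318], "tokens": ["This", " is"], "kld": [0.001, 0.004]},
        {"token_ids": [464, 3139], "tokens": ["The", " capital"], "kld": [0.002, 0.080]},
    ]
}

# Pool every prompt's per-token values into one array, then summarize.
kld = np.concatenate([np.asarray(p["kld"]) for p in results["prompts"]])
print(f"Mean KLD: {kld.mean():.6f}")
print(f"P95 KLD:  {np.percentile(kld, 95):.6f}")
print(f"Max KLD:  {kld.max():.6f}")
```

The same pooled array makes it easy to add metrics the CLI does not print, such as a histogram of per-token divergence or the worst-N positions across all prompts.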

## Dependencies
