A transformer model for computer-assisted multitrack music composition.
- Fill in missing bars while preserving your existing arrangement
- Generate new tracks from scratch, conditioned on musical attributes
- Steer the output by controlling note density, polyphony, and note duration — globally or per bar
- Integrate with your DAW via a real-time OSC server
- One-line setup — load pretrained models from HuggingFace Hub, no compiler needed
pip install "midigpt[inference]"

Pre-built wheels are available for CPython 3.10–3.12 on Linux (x86_64), macOS (x86_64 and arm64), and Windows (AMD64). No compiler needed.
| Extra | What it adds |
|---|---|
| inference | torch>=2.0, tqdm, huggingface_hub |
| train | PyTorch Lightning, HuggingFace datasets, pyarrow, python-dotenv |
| realtime | python-osc, Flask, Flask-SocketIO |
| http | FastAPI, uvicorn |
| dev | pytest, ruff, mypy |
| all | realtime + train |
Load a pretrained model from HuggingFace Hub and generate music in a few lines:
from midigpt import Score, Track, Bar
from midigpt.inference import InferenceEngine, GenerationRequest, InferenceConfig, TrackPrompt
engine = InferenceEngine.from_pretrained("yellow")
# 4-bar score with one empty melodic track
score = Score(tracks=[Track(bars=[Bar() for _ in range(4)])])
result = engine.session(
    score,
    GenerationRequest(
        tracks=[TrackPrompt(id=0, bars=[0, 1, 2, 3])],
        config=InferenceConfig(model_dim=4, mask_mode="attention"),
    ),
).run()
total = sum(len(b.notes) for t in result.tracks for b in t.bars)
print(f"Generated {total} notes")
result.to_midi("output.mid")

The model is downloaded once and cached by huggingface_hub in ~/.cache/huggingface/hub/.
| Name | num_bars_map | Infill | Attributes | Download |
|---|---|---|---|---|
| yellow | 4, 8 | yes | note density, polyphony (min/max), note duration (min/max) | yellow.pt |
| ghost | 4, 8, 12, 16 | yes | note density, polyphony (min/max), note duration (min/max) | coming soon |
| expressive | 4, 8 | yes | note density, polyphony (min/max), note duration (min/max) | coming soon |
model_dim in InferenceConfig is the context window in bars, not a vocabulary dimension — pass a value from the model's num_bars_map. expressive additionally encodes sub-grid timing via delta tokens.
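Because model_dim must be one of the context sizes listed for the chosen model, a small guard can catch a bad value before a session is built. The sketch below copies the table above into a hypothetical NUM_BARS_MAP dictionary and a hypothetical check_model_dim helper. Neither is part of midigpt, and a real application might read the supported sizes from the loaded checkpoint instead.

```python
# Hypothetical lookup copied from the model table in this README; not a midigpt API.
NUM_BARS_MAP = {
    "yellow": (4, 8),
    "ghost": (4, 8, 12, 16),
    "expressive": (4, 8),
}


def check_model_dim(model_name: str, model_dim: int) -> int:
    """Return model_dim if the named model lists it, otherwise raise ValueError."""
    try:
        supported = NUM_BARS_MAP[model_name]
    except KeyError:
        raise ValueError(f"unknown model {model_name!r}") from None
    if model_dim not in supported:
        raise ValueError(
            f"model_dim={model_dim} is not valid for {model_name!r}; "
            f"choose one of {list(supported)}"
        )
    return model_dim


print(check_model_dim("ghost", 12))
```

The returned value can be passed straight through, for example as InferenceConfig(model_dim=check_model_dim("yellow", 8)).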
# By name (downloads from Metacreation/MIDI-GPT on HuggingFace Hub)
engine = InferenceEngine.from_pretrained("yellow") # or "ghost", "expressive"
# From a local .pt bundle
engine = InferenceEngine.from_checkpoint("path/to/model.pt")

To regenerate selected bars of an existing MIDI file while leaving other tracks untouched:

score = Score.from_midi("my_song.mid")
request = GenerationRequest(
    tracks=[
        TrackPrompt(id=0, bars=[4, 5, 6, 7]),  # bars to regenerate
        TrackPrompt(id=1, bars=[], ignore=True),  # leave track 1 unchanged
    ],
    config=InferenceConfig(temperature=1.0, top_p=0.95, model_dim=8),
)
result = engine.session(score, request).run()
result.to_midi("output.mid")

To generate a track from scratch, conditioned on attributes and controls:

request = GenerationRequest(
    tracks=[
        TrackPrompt(
            id=0,
            bars=[],
            autoregressive=True,
            attributes={"max_polyphony": 3},  # quantized attribute level
            controls={"time_signature": 0},  # index into encoder TS list
        ),
    ],
    config=InferenceConfig(temperature=1.0, model_dim=8, polyphony_hard_limit=4),
)
result = engine.session(score, request).run()

| Class | Module | Purpose |
|---|---|---|
| InferenceEngine | midigpt.inference | Top-level loader and session factory |
| GenerationRequest | midigpt.inference | Bundle of per-track prompts and config |
| TrackPrompt | midigpt.inference | Per-track bars, mode, attributes, controls |
| InferenceConfig | midigpt.inference | Temperature, sampling filters, step planner |
| SamplingSession | midigpt.inference | Token-level sampling loop (returned by session()) |
| Field | Type | Default | Meaning |
|---|---|---|---|
| id | int | (none) | Track index in the score |
| bars | list[int] | (none) | Bars to generate |
| autoregressive | bool | False | Generate from scratch (no per-bar prompt) |
| ignore | bool | False | Omit this track from the token stream |
| mask_bars | list[int] | [] | Bars hidden with MASK_BAR (disjoint from bars) |
| attributes | dict[str,int] | {} | Quantized attribute overrides |
| controls | dict[str,Any] | {} | Token locks, e.g. {"time_signature": 0} |
| bar_attributes | dict[int,dict] | {} | Per-bar attribute overrides (absolute bar index) |
| bar_controls | dict[int,dict] | {} | Per-bar control overrides (absolute bar index) |
InferenceConfig exposes a four-stage logit-filtering pipeline (top_k → top_p → mask_k → mask_p):
| Field | Default | Meaning |
|---|---|---|
| top_k | 0 (off) | Keep top-k highest-probability tokens |
| top_p | 1.0 (off) | Nucleus: keep the smallest set summing to ≥ top_p |
| mask_k | 0 (off) | Remove the top-k most-likely tokens (novelty pressure) |
| mask_p | 0.0 (off) | Anti-nucleus: remove tokens summing to ≥ mask_p from the top |
A small mask_k=1 or mask_p=0.3 pushes the model off its most-confident picks — useful for getting diverse outputs when novelty_check=True.
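To make the order of the four stages concrete, here is a minimal pure-Python sketch that applies them to a list of probabilities. It is an illustration, not the library's implementation. The function name filter_probs is hypothetical. Renormalizing between stages and always keeping at least one token are assumptions. The input is assumed to hold strictly positive probabilities, as a softmax would produce.

```python
def filter_probs(probs, top_k=0, top_p=1.0, mask_k=0, mask_p=0.0):
    """Apply top_k, top_p, mask_k, mask_p in that order, then renormalize.

    Removed tokens get probability 0.0 in the returned list.
    """
    # Token indices ordered from most to least likely.
    keep = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    def prefix_reaching(threshold):
        # Length of the shortest prefix of `keep` whose renormalized
        # cumulative mass reaches `threshold`.
        total = sum(probs[i] for i in keep)
        mass = 0.0
        for count, i in enumerate(keep, start=1):
            mass += probs[i] / total
            if mass >= threshold:
                return count
        return len(keep)

    if top_k > 0:  # keep only the k most likely tokens
        keep = keep[:top_k]
    if top_p < 1.0:  # nucleus: smallest set whose mass reaches top_p
        keep = keep[:prefix_reaching(top_p)]
    if mask_k > 0:  # drop the k most likely survivors
        keep = keep[mask_k:] or keep[-1:]  # assumption: never drop everything
    if mask_p > 0.0:  # anti-nucleus: drop the top mass up to mask_p
        keep = keep[prefix_reaching(mask_p):] or keep[-1:]

    total = sum(probs[i] for i in keep)
    out = [0.0] * len(probs)
    for i in keep:
        out[i] = probs[i] / total
    return out


print(filter_probs([0.5, 0.25, 0.125, 0.125], mask_k=1))
# → [0.0, 0.5, 0.25, 0.25]
```

With mask_k=1 the most likely token is removed and the remaining mass is redistributed, which is the "novelty pressure" described above.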
Control how future bars appear in the context window:
| Mode | Behaviour |
|---|---|
| "token" | Encoder emits a MaskBar token (requires vocab support) |
| "attention" | Future bars zeroed in the KV cache via exact span masking |
| "attention_approx" | Single prefill mask + KV surgery; cheaper than "attention" |
| "attention_skip" | Future tokens filtered from input; position_ids passed explicitly |
| "remove" | Future bars omitted entirely from the token stream |
Set via InferenceConfig(mask_mode="attention"). "attention" works on all encoders; "token" requires the encoder vocab to include a MaskBar domain.
Introspect available controls at runtime:
engine._analyzer.attribute_sizes() # {"note_density": 10, "min_polyphony": 10, ...}
engine._analyzer.attribute_value_labels() # {"note_density": ["very sparse", ...], ...}
engine._analyzer.attribute_track_types() # {"note_density": "melodic", ...}

Pass quantized levels (integers in [0, size)) in TrackPrompt.attributes.
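A range check along these lines can reject out-of-range levels before they reach a TrackPrompt. This is a sketch under stated assumptions: validate_attributes is a hypothetical helper, and the sizes dictionary holds illustrative values in the shape shown in the attribute_sizes() comment above. In practice the sizes would come from the loaded engine.

```python
def validate_attributes(attributes: dict[str, int], sizes: dict[str, int]) -> None:
    """Raise if an attribute is unknown or its level falls outside [0, size)."""
    for name, level in attributes.items():
        if name not in sizes:
            raise KeyError(f"unknown attribute {name!r}; known: {sorted(sizes)}")
        if not 0 <= level < sizes[name]:
            raise ValueError(
                f"{name}={level} is out of range; expected 0 <= level < {sizes[name]}"
            )


# Illustrative sizes only (assumption); use the values reported by the engine.
sizes = {"note_density": 10, "min_polyphony": 10, "max_polyphony": 10}

validate_attributes({"max_polyphony": 3}, sizes)  # passes silently
```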
python -m midigpt.training.preprocess \
    --parquet /data/train/*.parquet \
    --checkpoint models/yellow.pt

This builds a valid-index cache so that dataset initialization is instant on subsequent runs. The cache is stored in ~/.midigpt/ (override with MIDIGPT_CACHE).
python -m midigpt.training.trainer \
    --config models/train_config.json \
    --train-data /data/train/*.parquet \
    --eval-data /data/valid/*.parquet \
    --output-dir checkpoints/run_001

from midigpt.training.trainer import TrainConfig, train

config = TrainConfig.from_file("models/train_config.json")
train(config, train_path="/data/train/00000.parquet", eval_path="/data/valid/00000.parquet")

train() uses PyTorch Lightning and, at the end of training, writes a packed .pt bundle containing the weights, architecture config, and encoder config.
| Field | Default | Notes |
|---|---|---|
| n_embd / n_layer / n_head | 512 / 6 / 8 | Model architecture |
| max_seq_len | 2048 | Token sequence cap |
| infill_probability | 0.75 | Fraction of samples trained with FillIn tokens |
| mask_apply_probability | 0.5 | Fraction of samples with MASK_BAR applied |
| precision | "fp16" | "fp16", "bf16", or "fp32" |
| logger | "none" | "tensorboard", "wandb", or "none" |
| num_workers | 0 | Must be 0 because the C++ MIDI parser is not fork-safe |
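As a starting point for a config file, the defaults in the table can be collected into a dictionary and serialized to JSON. This sketch assumes a flat key layout, which this README does not confirm; check TrainConfig for the schema it actually expects. DEFAULTS and make_config are hypothetical names.

```python
import json

# Defaults copied from the table above. The flat key layout is an assumption.
DEFAULTS = {
    "n_embd": 512,
    "n_layer": 6,
    "n_head": 8,
    "max_seq_len": 2048,
    "infill_probability": 0.75,
    "mask_apply_probability": 0.5,
    "precision": "fp16",
    "logger": "none",
    "num_workers": 0,
}


def make_config(**overrides):
    """Merge overrides into the defaults, rejecting unknown keys."""
    unknown = set(overrides) - set(DEFAULTS)
    if unknown:
        raise KeyError(f"unknown config fields: {sorted(unknown)}")
    config = {**DEFAULTS, **overrides}
    if config["num_workers"] != 0:
        # The table states that the C++ MIDI parser is not fork-safe.
        raise ValueError("num_workers must be 0")
    return config


text = json.dumps(make_config(precision="bf16", logger="tensorboard"), indent=2)
print(text)
```

The printed JSON could be saved as models/train_config.json and adjusted to whatever layout TrainConfig.from_file accepts.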
pip install "midigpt[http]"
# From a local checkpoint
midigpt-http --ckpt models/yellow.pt --port 8000
# From HuggingFace Hub (by name or repo ID)
midigpt-http --pretrained yellow --port 8000
midigpt-http --pretrained Metacreation/MIDI-GPT --hf-filename yellow.pt --port 8000

The server exposes a stateless REST API: every request carries the full score and generation parameters. Interactive API docs are available at http://localhost:8000/docs.
| Endpoint | Description |
|---|---|
| GET /health | Liveness probe |
| GET /info | Model capabilities and attribute sizes |
| POST /generate | {score, request} → {score, timing} |
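Since each POST /generate request carries the whole score, a client mostly has to assemble one JSON document. The standard-library sketch below builds the same body as this README's curl example (one melodic track, four empty 4/4 bars). build_payload and generate are hypothetical helper names, and the response is assumed to be the {score, timing} object listed in the table.

```python
import json
import urllib.request


def empty_bar(numerator=4, denominator=4):
    """One bar with a time signature and no notes."""
    return {"ts_numerator": numerator, "ts_denominator": denominator, "notes": []}


def build_payload(num_bars=4, model_dim=4):
    """Request body asking the server to generate every bar of track 0."""
    return {
        "score": {
            "resolution": 480,
            "tempo": 500000,
            "tracks": [
                {
                    "instrument": 0,
                    "track_type": "melodic",
                    "bars": [empty_bar() for _ in range(num_bars)],
                }
            ],
        },
        "request": {
            "tracks": [{"id": 0, "bars": list(range(num_bars))}],
            "config": {"model_dim": model_dim},
        },
    }


def generate(payload, base_url="http://localhost:8000"):
    """POST the payload to /generate and return the decoded JSON response."""
    request = urllib.request.Request(
        f"{base_url}/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request) as response:
        return json.load(response)
```

With a server running, generate(build_payload())["score"] would hold the generated score.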
# Score: 1 melodic track, 4 empty bars — generate all 4 from scratch
curl -s -X POST http://localhost:8000/generate \
  -H "Content-Type: application/json" \
  -d '{
    "score": {
      "resolution": 480, "tempo": 500000,
      "tracks": [{
        "instrument": 0, "track_type": "melodic",
        "bars": [
          {"ts_numerator": 4, "ts_denominator": 4, "notes": []},
          {"ts_numerator": 4, "ts_denominator": 4, "notes": []},
          {"ts_numerator": 4, "ts_denominator": 4, "notes": []},
          {"ts_numerator": 4, "ts_denominator": 4, "notes": []}
        ]
      }]
    },
    "request": {
      "tracks": [{"id": 0, "bars": [0, 1, 2, 3]}],
      "config": {"model_dim": 4}
    }
  }' | jq .score

Use --device cuda, --device mps, or --device auto (default) to select the compute device.
pip install "midigpt[realtime]"
midigpt-server --ckpt models/yellow.pt --port 7400

The server listens for OSC messages on a UDP port and streams generated notes back in real time. Generation is triggered bar-by-bar via /midigpt/bar/end and runs on a background thread.
Selected OSC addresses:
| Address | Direction | Description |
|---|---|---|
| /midigpt/session/init | in | Start a new session |
| /midigpt/track/create | in | Register a track |
| /midigpt/note | in | Push an incoming note |
| /midigpt/bar/end | in | Signal bar end (triggers generation) |
| /midigpt/param/set | in | Adjust sampling parameters at runtime |
| /midigpt/attr/set | in | Set attribute overrides |
| /midigpt/generated/note | out | Emit a generated note |
| /midigpt/generated/features | out | Per-bar statistics |
| /midigpt/capabilities | out | Attribute support for the loaded checkpoint |
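For clients that cannot use an OSC library, the wire format is simple enough to build by hand. The sketch below encodes an OSC 1.0 message with the standard library only and shows how it could be sent over UDP. The argument lists each address expects are not documented here, so the zero-argument /midigpt/bar/end message and the send_bar_end helper are assumptions for illustration. In Python, python-osc (installed by the realtime extra) is the more usual choice.

```python
import socket
import struct


def _pad(raw: bytes) -> bytes:
    """Null-terminate and pad to a multiple of 4 bytes, as OSC strings require."""
    raw += b"\x00"
    return raw + b"\x00" * (-len(raw) % 4)


def encode_osc(address: str, *args) -> bytes:
    """Encode an OSC 1.0 message with int32, float32, and string arguments."""
    tags, payload = ",", b""
    for arg in args:
        if isinstance(arg, bool):
            raise TypeError("bool arguments are not supported in this sketch")
        if isinstance(arg, int):
            tags += "i"
            payload += struct.pack(">i", arg)  # big-endian int32
        elif isinstance(arg, float):
            tags += "f"
            payload += struct.pack(">f", arg)  # big-endian float32
        elif isinstance(arg, str):
            tags += "s"
            payload += _pad(arg.encode("ascii"))
        else:
            raise TypeError(f"unsupported argument type: {type(arg).__name__}")
    return _pad(address.encode("ascii")) + _pad(tags.encode("ascii")) + payload


def send_bar_end(host="127.0.0.1", port=7400):
    """Send /midigpt/bar/end with no arguments (the argument list is an assumption)."""
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.sendto(encode_osc("/midigpt/bar/end"), (host, port))
```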
git clone https://github.com/Metacreation-Lab/MIDI-GPT.git
cd MIDI-GPT
pip install -e ".[inference,dev]"  # compiles the C++ extension in-place

Prerequisites: Python 3.10+, CMake 3.21+, a C++20 compiler.
# Python
pytest tests/python/
pytest tests/python -m "not slow and not inference" # CI subset (no model needed)
# C++
cmake -S . -B build_cpp -DCMAKE_BUILD_TYPE=Release
cmake --build build_cpp -j
ctest --test-dir build_cpp --output-on-failure

ruff check src/ tests/   # lint
ruff format src/ tests/  # format

pre-commit runs both automatically on commit:

pip install pre-commit && pre-commit install

Tagging a commit vX.Y.Z triggers .github/workflows/wheels.yml, which builds wheels on Linux, macOS, and Windows for Python 3.10–3.12, drafts a GitHub Release, and publishes to PyPI via OIDC Trusted Publishing.
Set MIDIGPT_LOG_LEVEL before importing midigpt, for example MIDIGPT_LOG_LEVEL=DEBUG. The variable accepts level names (DEBUG, INFO, WARNING) as well as numeric levels.
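The sketch below shows one way a value that may be either a level name or an integer can be resolved with the standard logging module. It illustrates the accepted formats and is not the library's own parsing code. resolve_log_level and its WARNING fallback are assumptions.

```python
import logging
import os


def resolve_log_level(value, default=logging.WARNING):
    """Map a level name such as 'DEBUG' or a numeric string such as '10' to an int."""
    if value is None or not value.strip():
        return default
    text = value.strip()
    if text.isdigit():
        return int(text)
    # getLevelName returns the numeric level for a known name,
    # and a placeholder string for an unknown one.
    level = logging.getLevelName(text.upper())
    return level if isinstance(level, int) else default


level = resolve_log_level(os.environ.get("MIDIGPT_LOG_LEVEL"))
print(logging.getLevelName(level))
```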
@misc{pasquier2025midigptcontrollablegenerativemodel,
title={MIDI-GPT: A Controllable Generative Model for Computer-Assisted Multitrack Music Composition},
author={Philippe Pasquier and Jeff Ens and Nathan Fradet and Paul Triana and Davide Rizzotti and Jean-Baptiste Rolland and Maryam Safi},
year={2025},
eprint={2501.17011},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2501.17011},
}

MIT License. Copyright (c) 2026 Metacreation Lab. See LICENSE.