Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added mimi_codebook_tsne.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
85 changes: 85 additions & 0 deletions visualize_mimi_codebooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
Visualize Mimi codec codebook embeddings in 2D using t-SNE.

Downloads kyutai/mimi from HuggingFace and plots the codebook embeddings
from mimi.quantizer.acoustic_residual_vector_quantizer.layers[q].codebook.embed
reduced to 2 dimensions via t-SNE for selected quantizer layers.
"""

import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from transformers import AutoModel

# Indices into the acoustic residual vector quantizer's ``layers`` whose
# codebooks get plotted. main() skips (with a warning) any index that is
# >= the number of layers the loaded model actually has.
QUANTIZER_INDICES = [0, 1, 2, 3, 4, 8, 16, 30]


def _collect_codebooks(quantizer, indices):
    """Return ``{layer_index: 2-D codebook array}`` for every in-range index.

    Args:
        quantizer: Object exposing ``layers``; each layer must provide
            ``codebook.embed`` as a tensor.
        indices: Iterable of layer indices to extract.

    Returns:
        Dict mapping each accepted index to its codebook as a NumPy array.
        Indices >= ``len(quantizer.layers)`` are reported and left out.
    """
    num_layers = len(quantizer.layers)
    print(f"Number of acoustic quantizer layers: {num_layers}")

    embeddings = {}
    for q in indices:
        if q >= num_layers:
            print(f"WARNING: quantizer index {q} out of range (max {num_layers - 1}), skipping")
            continue
        embed = quantizer.layers[q].codebook.embed.detach().cpu().numpy()
        print(f" q={q}: codebook shape = {embed.shape}")
        # The codebook may carry a leading singleton axis; drop it so every
        # stored array is 2-D (vocab_size, dim) as the plotting code expects.
        if embed.ndim == 3:
            embed = embed.squeeze(0)
        embeddings[q] = embed
    return embeddings


def _plot_tsne_grid(embeddings, out_path):
    """Run t-SNE on each codebook and save all scatter plots to *out_path*.

    Args:
        embeddings: Non-empty dict of ``{layer_index: (vocab_size, dim) array}``.
        out_path: Destination image file for the combined figure.
    """
    # The figure title reports the shape of the first collected codebook.
    vocab_size, dim = next(iter(embeddings.values())).shape
    print(f"\nVocab size: {vocab_size}, Embedding dim: {dim}")
    print("Running t-SNE for each quantizer layer...\n")

    # One panel per codebook, four panels per row (ceil division for rows).
    n = len(embeddings)
    ncols = 4
    nrows = (n + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows))
    # plt.subplots returns a 1-D or 2-D array depending on the grid;
    # flatten so panels can be addressed with a single running index.
    axes = np.array(axes).flatten()

    for idx, (q, embed) in enumerate(sorted(embeddings.items())):
        print(f" t-SNE for q={q}...")
        # scikit-learn requires perplexity < n_samples; clamp so a small
        # codebook does not abort the whole run. Stays 30 for any codebook
        # with more than 30 entries.
        perplexity = min(30, max(embed.shape[0] - 1, 1))
        tsne = TSNE(
            n_components=2,
            random_state=42,
            perplexity=perplexity,
            max_iter=1000,
        )
        reduced = tsne.fit_transform(embed)

        ax = axes[idx]
        ax.scatter(reduced[:, 0], reduced[:, 1], s=3, alpha=0.5)
        ax.set_title(f"Quantizer q={q}", fontsize=13)
        ax.set_xlabel("t-SNE 1")
        ax.set_ylabel("t-SNE 2")
        ax.tick_params(labelsize=8)

    # Hide the grid cells left over when n is not a multiple of ncols.
    for idx in range(n, len(axes)):
        axes[idx].set_visible(False)

    fig.suptitle(
        f"Mimi Acoustic Codebook Embeddings (vocab={vocab_size}, dim={dim}) — t-SNE 2D",
        fontsize=15,
        y=1.01,
    )
    plt.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    print(f"\nSaved plot to {out_path}")
    plt.close(fig)


def main():
    """Load kyutai/mimi, t-SNE its selected acoustic codebooks, save a PNG.

    Reads the layer selection from ``QUANTIZER_INDICES`` and writes the
    figure to ``mimi_codebook_tsne.png`` in the current working directory.
    Returns early, after printing a message, if no requested layer exists.
    """
    print("Loading kyutai/mimi from HuggingFace...")
    # NOTE(review): trust_remote_code=True permits running code shipped in
    # the model repo. Presumably unnecessary if the installed transformers
    # release supports Mimi natively -- confirm and drop if so.
    model = AutoModel.from_pretrained("kyutai/mimi", trust_remote_code=True)
    model.eval()

    quantizer = model.quantizer.acoustic_residual_vector_quantizer
    embeddings = _collect_codebooks(quantizer, QUANTIZER_INDICES)

    if not embeddings:
        print("No embeddings found!")
        return

    _plot_tsne_grid(embeddings, "mimi_codebook_tsne.png")


# Run the visualization only when executed as a script, not when imported.
if __name__ == "__main__":
    main()