diff --git a/mimi_codebook_tsne.png b/mimi_codebook_tsne.png new file mode 100644 index 0000000..1c2a428 Binary files /dev/null and b/mimi_codebook_tsne.png differ diff --git a/visualize_mimi_codebooks.py b/visualize_mimi_codebooks.py new file mode 100644 index 0000000..c05ac98 --- /dev/null +++ b/visualize_mimi_codebooks.py @@ -0,0 +1,85 @@ +""" +Visualize Mimi codec codebook embeddings in 2D using t-SNE. + +Downloads kyutai/mimi from HuggingFace and plots the codebook embeddings +from mimi.quantizer.acoustic_residual_vector_quantizer.layers[q].codebook.embed +reduced to 2 dimensions via t-SNE for selected quantizer layers. +""" + +import torch +import numpy as np +import matplotlib.pyplot as plt +from sklearn.manifold import TSNE +from transformers import AutoModel + +QUANTIZER_INDICES = [0, 1, 2, 3, 4, 8, 16, 30] + + +def main(): + print("Loading kyutai/mimi from HuggingFace...") + model = AutoModel.from_pretrained("kyutai/mimi", trust_remote_code=True) + model.eval() + + # Inspect quantizer structure + quantizer = model.quantizer.acoustic_residual_vector_quantizer + num_layers = len(quantizer.layers) + print(f"Number of acoustic quantizer layers: {num_layers}") + + # Collect embeddings for each requested quantizer index + embeddings = {} + for q in QUANTIZER_INDICES: + if q >= num_layers: + print(f"WARNING: quantizer index {q} out of range (max {num_layers - 1}), skipping") + continue + embed = quantizer.layers[q].codebook.embed.detach().cpu().numpy() + print(f" q={q}: codebook shape = {embed.shape}") + # embed shape is typically (1, vocab_size, dim) or (vocab_size, dim) + if embed.ndim == 3: + embed = embed.squeeze(0) + embeddings[q] = embed + + if not embeddings: + print("No embeddings found!") + return + + vocab_size, dim = next(iter(embeddings.values())).shape + print(f"\nVocab size: {vocab_size}, Embedding dim: {dim}") + print("Running t-SNE for each quantizer layer...\n") + + # Create subplot grid + n = len(embeddings) + ncols = 4 + nrows = (n + ncols - 1) // ncols + fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows)) + axes = np.array(axes).flatten() + + for idx, (q, embed) in enumerate(sorted(embeddings.items())): + print(f" t-SNE for q={q}...") + tsne = TSNE(n_components=2, random_state=42, perplexity=30, max_iter=1000) + reduced = tsne.fit_transform(embed) + + ax = axes[idx] + ax.scatter(reduced[:, 0], reduced[:, 1], s=3, alpha=0.5) + ax.set_title(f"Quantizer q={q}", fontsize=13) + ax.set_xlabel("t-SNE 1") + ax.set_ylabel("t-SNE 2") + ax.tick_params(labelsize=8) + + # Hide unused subplots + for idx in range(n, len(axes)): + axes[idx].set_visible(False) + + fig.suptitle( + f"Mimi Acoustic Codebook Embeddings (vocab={vocab_size}, dim={dim}) — t-SNE 2D", + fontsize=15, + y=1.01, + ) + plt.tight_layout() + out_path = "mimi_codebook_tsne.png" + fig.savefig(out_path, dpi=180, bbox_inches="tight") + print(f"\nSaved plot to {out_path}") + plt.close(fig) + + +if __name__ == "__main__": + main()