This repository extends the diffusion-classifier project with learnable template functionality for improved text-conditional image classification using diffusion models.
It also contains two complementary approaches for making diffusion-based image classification more efficient: hierarchical clustering and beam search. Both methods leverage CLIP embeddings to organize class labels into semantic hierarchies and reduce computational cost during inference.
This extension enhances the original diffusion-classifier by replacing static text prompts with learnable template embeddings. Instead of using fixed prompts like "a photo of a [class]", the system learns optimal template representations that are specifically tuned for classification tasks using diffusion models. The project implements two main approaches:
- Standard Diffusion Classification: Uses pre-defined text prompts to classify images based on diffusion model reconstruction errors
- Learnable Templates: Automatically learns optimal text templates that improve classification performance
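As an illustration of the first mode, here is a minimal sketch of diffusion-error scoring, assuming a Stable Diffusion pipeline loaded through Hugging Face diffusers; the function and variable names are illustrative, not the repository's actual API:

```python
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)

@torch.no_grad()
def class_errors(latents, prompts, n_samples=50):
    """Average noise-prediction error of each candidate prompt for one image.

    `latents` is the VAE-encoded image, e.g.
    pipe.vae.encode(img).latent_dist.mean * 0.18215 for SD v1.
    The predicted class is the prompt with the lowest error."""
    errors = []
    for prompt in prompts:
        ids = pipe.tokenizer(prompt, padding="max_length",
                             max_length=pipe.tokenizer.model_max_length,
                             truncation=True, return_tensors="pt").input_ids.to(device)
        text_emb = pipe.text_encoder(ids)[0]
        total = 0.0
        for _ in range(n_samples):
            # Noise the latents at a random timestep ...
            t = torch.randint(0, pipe.scheduler.config.num_train_timesteps,
                              (1,), device=device)
            noise = torch.randn_like(latents)
            noisy = pipe.scheduler.add_noise(latents, noise, t)
            # ... and ask the UNet to recover the noise, conditioned on the prompt.
            pred = pipe.unet(noisy, t, encoder_hidden_states=text_emb).sample
            total += F.l1_loss(pred, noise).item()
        errors.append(total / n_samples)
    return errors
```

Taking the argmin over the returned errors gives the prediction; the L1 error here mirrors the `--loss l1` setting used in the commands below.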
Moreover, traditional diffusion-based classification evaluates all classes simultaneously, which can be computationally expensive for datasets with many classes. This project introduces two hierarchical approaches:
- Hierarchical Clustering: Uses a tree-based approach where classification decisions are made level by level, progressively narrowing down the candidates.
- Beam Search: Maintains multiple candidate hypotheses (beams) at each hierarchical level and selects the top-k most promising paths.
Both approaches use CLIP embeddings to build semantic hierarchies of class labels, allowing for more efficient and interpretable classification.
Create a conda environment with the following command:
```bash
conda env create -f environment.yml
```

Run classification using pre-defined prompts:

```bash
python diffusion_classifier.py \
--dataset cifar10 \
--split test \
--prompt_path prompts/cifar10_prompts.csv \
--to_keep 5 1 \
--n_samples 50 500 \
--loss l1 \
--n_trials 1 \
--samples_per_class 100
```

| Parameter | Description |
|---|---|
| --to_keep | Number of classes to keep at each stage |
| --n_samples | Number of diffusion timesteps to sample at each stage |
| --n_trials | Number of times each sample is evaluated during the experiment |
| --samples_per_class | Number of samples per class, used to build a balanced test subset |
Train optimal templates for a specific dataset:
```bash
python learnable_templates.py \
--dataset cifar10 \
--split train \
--version 2-0 \
--dtype float16
```
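Conceptually, training replaces the hand-written template words with a small set of context vectors optimized end-to-end. Below is a minimal CoOp-style sketch, where `PromptLearner` and its shapes are hypothetical stand-ins for the repository's implementation:

```python
import torch
import torch.nn as nn

class PromptLearner(nn.Module):
    """Hypothetical CoOp-style learnable template (not the repo's exact class)."""

    def __init__(self, class_embeddings: torch.Tensor, n_ctx: int = 8, dim: int = 768):
        super().__init__()
        # Learnable template: shared context vectors, randomly initialized.
        self.ctx = nn.Parameter(torch.randn(n_ctx, dim) * 0.02)
        # Frozen per-class token embeddings from the text encoder.
        self.register_buffer("cls", class_embeddings)  # (n_classes, n_tokens, dim)

    def forward(self) -> torch.Tensor:
        n_classes = self.cls.shape[0]
        ctx = self.ctx.unsqueeze(0).expand(n_classes, -1, -1)
        # [learned context | class tokens] replaces "a photo of a [class]".
        return torch.cat([ctx, self.cls], dim=1)

# Training (schematic): pair each image with its own class prompt, compute the
# diffusion noise-prediction loss, and backpropagate into `self.ctx` only.
```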
Run evaluation with the learned templates:

```bash
python diffusion_classifier.py \
--dataset cifar10 \
--split test \
--to_keep 5 1 \
--n_samples 50 500 \
--loss l1 \
--n_trials 1 \
--samples_per_class 100 \
--template_path ./templates/prompt_learner{i}.pt
```

Run classification with hierarchical clustering:

```bash
python clustering_diffusion_classifier.py \
--dataset cifar10 \
--split test \
--n_trials 1 \
--loss l1 \
--samples_per_class 10 \
--prompt_path prompts/cifar10_prompts.csv \
--use_clustering \
--cluster_depth 3
```

Run classification with beam search:

```bash
python beam_search_diffusion_classifier.py \
--dataset cifar10 \
--split test \
--n_trials 1 \
--loss l1 \
--samples_per_class 10 \
--prompt_path prompts/cifar10_prompts.csv \
--use_clustering \
--beam_width 2
```

Both approaches start by building a semantic hierarchy:
- Extract Class Names: Parse unique class names from the prompt CSV
- CLIP Embedding: Generate CLIP text embeddings for each class name
- Hierarchical Clustering: Use agglomerative clustering to build a tree structure
- Centroid Calculation: For each internal node, identify the centroid class (most representative)
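A compact sketch of steps 3 and 4, assuming SciPy's agglomerative linkage over CLIP text features; the helper names are illustrative:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage

def build_tree(clip_text_embeddings: np.ndarray) -> np.ndarray:
    """Ward-linkage agglomerative clustering over normalized CLIP embeddings."""
    emb = clip_text_embeddings / np.linalg.norm(clip_text_embeddings,
                                                axis=1, keepdims=True)
    return linkage(emb, method="ward")  # (n-1, 4) merge matrix encoding the tree

def centroid_class(emb: np.ndarray, members: list) -> int:
    """The member closest to the cluster mean serves as its representative."""
    center = emb[members].mean(axis=0)
    return members[int(np.argmin(np.linalg.norm(emb[members] - center, axis=1)))]
```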
Hierarchical clustering inference then proceeds top-down (sketched below):
- Start from the root with all classes
- At each depth level:
  - Get the clusters containing the current candidates
  - Evaluate the representative class of each cluster
  - Select the best cluster (lowest diffusion error)
  - Continue with the classes from the selected cluster
- Run a final evaluation on the remaining candidates
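In code, the greedy descent might look like this sketch; the tree accessors `children`, `representative`, and `members` are hypothetical stand-ins for the cached tree structure:

```python
def hierarchical_classify(score_fn, root, children, representative, members):
    """score_fn(class_id) -> diffusion error; lower is better."""
    node = root
    while node in children:                     # descend until a leaf cluster
        # Evaluate one representative class per child cluster ...
        errs = {k: score_fn(representative[k]) for k in children[node]}
        node = min(errs, key=errs.get)          # ... and follow the best branch.
    # Final exhaustive evaluation over the surviving candidates.
    final = {c: score_fn(c) for c in members[node]}
    return min(final, key=final.get)
```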
Beam search inference instead keeps several hypotheses alive (sketched after this list):
- Maintain multiple candidate paths (beams) simultaneously
- At each depth level:
  - For each beam, evaluate the cluster representatives
  - Generate new beams from all possible cluster choices
  - Keep the top-k beams based on cumulative scores
- Combine candidates from all surviving beams for the final evaluation
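Under the same hypothetical tree accessors, the beam-search variant keeps the `--beam_width` best partial paths instead of committing to one:

```python
def beam_search_classify(score_fn, root, children, representative, members,
                         beam_width=2):
    """score_fn(class_id) -> diffusion error; lower is better."""
    beams = [(0.0, root)]                     # (cumulative error, node)
    while any(node in children for _, node in beams):
        expanded = []
        for cum_err, node in beams:
            if node not in children:          # leaf beams survive unchanged
                expanded.append((cum_err, node))
                continue
            for child in children[node]:      # expand every cluster choice
                expanded.append((cum_err + score_fn(representative[child]), child))
        # Keep only the top-k lowest-error paths.
        beams = sorted(expanded, key=lambda b: b[0])[:beam_width]
    # Final evaluation pools candidates from all surviving beams.
    candidates = {c for _, node in beams for c in members[node]}
    final = {c: score_fn(c) for c in candidates}
    return min(final, key=final.get)
```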
Each run generates comprehensive analysis in the output folder:
- confusion_matrix.png: Visual confusion matrix
- classification_report.txt: Detailed per-class metrics
- results_summary.txt: Overall performance summary
- hierarchical_tree_labels_clip.pkl: Cached tree structure
- depth_error_histogram.png: Distribution of errors by hierarchical depth (hierarchical clustering only)
- depth_error_stats.txt: Detailed depth error analysis (hierarchical clustering only)