docs: add multiple CIFAR10 examples using Reactant
avik-pal committed Dec 31, 2024
1 parent 3f8b231 commit 00c3206
Showing 10 changed files with 290 additions and 202 deletions.
4 changes: 2 additions & 2 deletions docs/src/.vitepress/config.mts
@@ -243,8 +243,8 @@ export default defineConfig({
link: "https://github.com/LuxDL/Lux.jl/tree/main/examples/DDIM",
},
{
text: "ConvMixer on CIFAR-10",
link: "https://github.com/LuxDL/Lux.jl/tree/main/examples/ConvMixer",
text: "Different Vision Models on CIFAR-10",
link: "https://github.com/LuxDL/Lux.jl/tree/main/examples/CIFAR10",
},
],
},
6 changes: 3 additions & 3 deletions docs/src/tutorials/index.md
@@ -97,10 +97,10 @@ const large_models = [
desc: "Train a Diffusion Model to generate images from Gaussian noises."
},
{
href: "https://github.com/LuxDL/Lux.jl/tree/main/examples/ConvMixer",
href: "https://github.com/LuxDL/Lux.jl/tree/main/examples/CIFAR10",
src: "https://datasets.activeloop.ai/wp-content/uploads/2022/09/CIFAR-10-dataset-Activeloop-Platform-visualization-image-1.webp",
caption: "ConvMixer on CIFAR-10",
desc: "Train ConvMixer on CIFAR-10 to 90% accuracy within 10 minutes."
caption: "Vision Models on CIFAR-10",
desc: "Train different vision models on CIFAR-10 to 90% accuracy within 10 minutes."
}
];
examples/CIFAR10/Project.toml
@@ -2,6 +2,7 @@
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
@@ -11,10 +12,10 @@ MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
PreferenceTools = "ba661fbb-e901-4445-b070-854aec6bfbc5"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
ProgressTables = "e0b4b9f6-8cc7-451e-9c86-94c5316e9f73"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -23,6 +24,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Comonicon = "1.0.8"
ConcreteStructs = "0.2.3"
DataAugmentation = "0.3"
Enzyme = "0.13.14"
ImageCore = "0.10.2"
ImageShow = "0.3.8"
Interpolations = "0.15.1"
@@ -32,10 +34,9 @@ MLDatasets = "0.7.14"
MLUtils = "0.4.4"
OneHotArrays = "0.2.5"
Optimisers = "0.4.1"
PreferenceTools = "0.1.2"
Printf = "1.10"
ProgressBars = "1.5.1"
Random = "1.10"
Reactant = "0.2.5"
StableRNGs = "1.0.2"
Statistics = "1.10"
Zygote = "0.6.70"
54 changes: 54 additions & 0 deletions examples/CIFAR10/README.md
@@ -0,0 +1,54 @@
# Train Vision Models on CIFAR-10

✈️ 🚗 🐦 🐈 🦌 🐕 🐸 🐎 🚢 🚚

We have the following scripts to train vision models on CIFAR-10:

1. `simple_cnn.jl`: Simple CNN model with a sequence of convolutional layers.
2. `mlp_mixer.jl`: MLP-Mixer model.
3. `conv_mixer.jl`: ConvMixer model.

To get the options for each script, run the script with the `--help` flag.
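
For example, to see the options for the simple CNN script (run from this directory):

```bash
julia --startup-file=no --project=. simple_cnn.jl --help
```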

> [!NOTE]
> To train a model using Reactant.jl, pass `--backend=reactant` to the script. This is
> the recommended way to train the models in this directory.

## Simple CNN

```bash
julia --startup-file=no \
    --project=. \
    --threads=auto \
    simple_cnn.jl \
    --backend=reactant
```

On an RTX 4050 6GB Laptop GPU, training takes approximately 3 minutes, and the final
training and test accuracies are 97% and 65%, respectively.

## MLP-Mixer

## ConvMixer

> [!NOTE]
> This code has been adapted from https://github.com/locuslab/convmixer-cifar10

This is a simple ConvMixer training script for CIFAR-10. It's probably a good starting
point for new experiments on small datasets.

You can get around **90.0%** accuracy in just **25 epochs** by running the script with the
following arguments, which trains a ConvMixer-256/8 with kernel size 5 and patch size 2.

```bash
julia --startup-file=no \
    --project=. \
    --threads=auto \
    conv_mixer.jl \
    --backend=reactant
```
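
Comonicon turns the script's keyword arguments into CLI flags, so (assuming its usual
underscore-to-dash mapping) overriding the defaults looks something like:

```bash
julia --startup-file=no --project=. conv_mixer.jl \
    --backend=reactant --lr-max=0.05 --kernel-size=5 --patch-size=2
```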

### Notes

1. To match the results from the original repo, we need additional augmentation
   strategies that are not currently implemented in DataAugmentation.jl.
139 changes: 139 additions & 0 deletions examples/CIFAR10/common.jl
@@ -0,0 +1,139 @@
using ConcreteStructs, DataAugmentation, ImageShow, Lux, MLDatasets, MLUtils, OneHotArrays,
    Printf, ProgressTables, Random
using Reactant, LuxCUDA

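# Wraps an MLDatasets dataset together with a DataAugmentation.jl transform that is
# applied lazily whenever the DataLoader indexes into it.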
@concrete struct TensorDataset
    dataset
    transform
end

Base.length(ds::TensorDataset) = length(ds.dataset)

function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, AbstractRange})
    img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3))
    y = onehotbatch(ds.dataset.targets[idxs], 0:9)
    return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img), y
end

function get_cifar10_dataloaders(batchsize; kwargs...)
    cifar10_mean = (0.4914, 0.4822, 0.4465)
    cifar10_std = (0.2471, 0.2435, 0.2616)

    train_transform = RandomResizeCrop((32, 32)) |>
                      Maybe(FlipX{2}()) |>
                      ImageToTensor() |>
                      Normalize(cifar10_mean, cifar10_std)

    test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std)

    trainset = TensorDataset(CIFAR10(:train), train_transform)
    trainloader = DataLoader(trainset; batchsize, shuffle=true, kwargs...)

    testset = TensorDataset(CIFAR10(:test), test_transform)
    testloader = DataLoader(testset; batchsize, shuffle=false, kwargs...)

    return trainloader, testloader
end

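# Top-1 accuracy over a dataloader; logits and targets are moved to the CPU before
# `onecold` extracts the predicted classes.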
function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    cdev = cpu_device()
    for (x, y) in dataloader
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end

function get_accelerator_device(backend::String)
    if backend == "gpu_if_available"
        return gpu_device()
    elseif backend == "gpu"
        return gpu_device(; force=true)
    elseif backend == "reactant"
        return reactant_device(; force=true)
    elseif backend == "cpu"
        return cpu_device()
    else
        error("Invalid backend: $(backend). Valid options are: `gpu_if_available`, `gpu`, \
               `reactant`, and `cpu`.")
    end
end

function train_model(
        model, opt, scheduler=nothing;
        backend::String, batchsize::Int=512, seed::Int=1234, epochs::Int=25
)
    rng = Random.default_rng()
    Random.seed!(rng, seed)

    accelerator_device = get_accelerator_device(backend)
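    # Reactant compiles for fixed array shapes, so drop the last incomplete batch
    # (partial=false) to keep every batch the same size.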
    kwargs = accelerator_device isa ReactantDevice ? (; partial=false) : ()
    trainloader, testloader = get_cifar10_dataloaders(batchsize; kwargs...) |>
                              accelerator_device

    ps, st = Lux.setup(rng, model) |> accelerator_device

    train_state = Training.TrainState(model, ps, st, opt)

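    # The Reactant backend differentiates with Enzyme; the other backends use Zygote.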
    adtype = backend == "reactant" ? AutoEnzyme() : AutoZygote()

    if backend == "reactant"
        x_ra = rand(rng, Float32, size(first(trainloader)[1])) |> accelerator_device
        @printf "[Info] Compiling model with Reactant.jl\n"
        st_test = Lux.testmode(st)
        model_compiled = Reactant.compile(model, (x_ra, ps, st_test))
        @printf "[Info] Model compiled!\n"
    else
        model_compiled = model
    end

    loss_fn = CrossEntropyLoss(; logits=Val(true))

    pt = ProgressTable(;
        header=[
            "Epoch", "Learning Rate", "Train Accuracy (%)", "Test Accuracy (%)", "Time (s)"
        ],
        widths=[24, 24, 24, 24, 24],
        format=["%3d", "%.6f", "%.6f", "%.6f", "%.6f"],
        color=[:normal, :normal, :blue, :blue, :normal],
        border=true,
        alignment=[:center, :center, :center, :center, :center]
    )

    @printf "[Info] Training model\n"
    initialize(pt)

    for epoch in 1:epochs
        stime = time()
        lr = 0
        for (i, (x, y)) in enumerate(trainloader)
            if scheduler !== nothing
                lr = scheduler((epoch - 1) + (i + 1) / length(trainloader))
                train_state = Optimisers.adjust!(train_state, lr)
            end
            (_, loss, _, train_state) = Training.single_train_step!(
                adtype, loss_fn, (x, y), train_state
            )
            isnan(loss) && error("NaN loss encountered!")
        end
        ttime = time() - stime

        train_acc = accuracy(
            model_compiled, train_state.parameters,
            Lux.testmode(train_state.states), trainloader
        ) * 100
        test_acc = accuracy(
            model_compiled, train_state.parameters,
            Lux.testmode(train_state.states), testloader
        ) * 100

        scheduler === nothing && (lr = NaN32)
        next(pt, [epoch, lr, train_acc, test_acc, ttime])
    end

    finalize(pt)
    @printf "[Info] Finished training\n"
end
50 changes: 50 additions & 0 deletions examples/CIFAR10/conv_mixer.jl
@@ -0,0 +1,50 @@
using Comonicon, Interpolations, Lux, Optimisers, Printf, Random, Statistics, Zygote, Enzyme

include("common.jl")

function ConvMixer(; dim, depth, kernel_size=5, patch_size=2)
    #! format: off
    return Chain(
        Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size),
        BatchNorm(dim),
        [
            Chain(
                SkipConnection(
                    Chain(
                        Conv(
                            (kernel_size, kernel_size), dim => dim, gelu;
                            groups=dim, pad=SamePad()
                        ),
                        BatchNorm(dim)
                    ),
                    +
                ),
                Conv((1, 1), dim => dim, gelu),
                BatchNorm(dim)
            )
            for _ in 1:depth
        ]...,
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(dim => 10)
    )
    #! format: on
end

Comonicon.@main function main(;
        batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
        patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=0.0001,
        clip_norm::Bool=false, seed::Int=1234, epochs::Int=25, lr_max::Float64=0.05,
        backend::String="reactant"
)
    model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size)

    opt = AdamW(; eta=lr_max, lambda=weight_decay)
    clip_norm && (opt = OptimiserChain(ClipNorm(), opt))

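    # Piecewise-linear LR schedule: warm up to lr_max at 40% of training, decay to
    # lr_max / 20 at 80%, then anneal to 0 by the final epoch.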
    lr_schedule = linear_interpolation(
        [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]
    )

    return train_model(model, opt, lr_schedule; backend, batchsize, seed, epochs)
end
3 changes: 3 additions & 0 deletions examples/CIFAR10/mlp_mixer.jl
@@ -0,0 +1,3 @@
using Comonicon, Lux, Optimisers, Printf, Random, Statistics, Zygote, Enzyme

include("common.jl")
34 changes: 34 additions & 0 deletions examples/CIFAR10/simple_cnn.jl
@@ -0,0 +1,34 @@
using Comonicon, Lux, Optimisers, Printf, Random, Statistics, Zygote, Enzyme

include("common.jl")

function SimpleCNN()
    return Chain(
        Conv((3, 3), 3 => 16, gelu; stride=2, pad=1),
        BatchNorm(16),
        Conv((3, 3), 16 => 32, gelu; stride=2, pad=1),
        BatchNorm(32),
        Conv((3, 3), 32 => 64, gelu; stride=2, pad=1),
        BatchNorm(64),
        Conv((3, 3), 64 => 128, gelu; stride=2, pad=1),
        BatchNorm(128),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(128 => 64, gelu),
        BatchNorm(64),
        Dense(64 => 10)
    )
end

Comonicon.@main function main(;
        batchsize::Int=512, weight_decay::Float64=0.0001,
        clip_norm::Bool=false, seed::Int=1234, epochs::Int=50, lr::Float64=0.003,
        backend::String="reactant"
)
    model = SimpleCNN()

    opt = AdamW(; eta=lr, lambda=weight_decay)
    clip_norm && (opt = OptimiserChain(ClipNorm(), opt))

    return train_model(model, opt, nothing; backend, batchsize, seed, epochs)
end