-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
docs: add multiple CIFAR10 examples using Reactant
- Loading branch information
Showing
10 changed files
with
290 additions
and
202 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Train Vision Models on CIFAR-10 | ||
|
||
✈️ 🚗 🐦 🐈 🦌 🐕 🐸 🐎 🚢 🚚 | ||
|
||
We have the following scripts to train vision models on CIFAR-10: | ||
|
||
1. `simple_cnn.jl`: Simple CNN model with a sequence of convolutional layers. | ||
2. `mlp_mixer.jl`: MLP-Mixer model. | ||
3. `conv_mixer.jl`: ConvMixer model. | ||
|
||
To get the options for each script, run the script with the `--help` flag. | ||
|
||
> [!NOTE] | ||
> To train the model using Reactant.jl pass in `--backend=reactant` to the script. This is | ||
> the recommended approach to train the models present in this directory. | ||
## Simple CNN | ||
|
||
```bash | ||
julia --startup-file=no \ | ||
--project=. \ | ||
--threads=auto \ | ||
simple_cnn.jl \ | ||
--backend=reactant | ||
``` | ||
|
||
On a RTX 4050 6GB Laptop GPU the training takes approximately 3 mins and the final training | ||
and test accuracies are 97% and 65%, respectively. | ||
|
||
## MLP-Mixer | ||
|
||
## ConvMixer | ||
|
||
> [!NOTE] | ||
> This code has been adapted from https://github.com/locuslab/convmixer-cifar10 | ||
This is a simple ConvMixer training script for CIFAR-10. It's probably a good starting point | ||
for new experiments on small datasets. | ||
|
||
You can get around **90.0%** accuracy in just **25 epochs** by running the script with the | ||
following arguments, which trains a ConvMixer-256/8 with kernel size 5 and patch size 2. | ||
|
||
```bash | ||
julia --startup-file=no \ | ||
--project=. \ | ||
--threads=auto \ | ||
conv_mixer.jl \ | ||
--backend=reactant | ||
``` | ||
|
||
### Notes | ||
|
||
1. To match the results from the original repo, we need more augmentation strategies, that | ||
are currently not implemented in DataAugmentation.jl. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
using ConcreteStructs, DataAugmentation, ImageShow, Lux, MLDatasets, MLUtils, OneHotArrays, | ||
Printf, ProgressTables, Random | ||
using Reactant, LuxCUDA | ||
|
||
@concrete struct TensorDataset | ||
dataset | ||
transform | ||
end | ||
|
||
Base.length(ds::TensorDataset) = length(ds.dataset) | ||
|
||
function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, AbstractRange}) | ||
img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3)) | ||
y = onehotbatch(ds.dataset.targets[idxs], 0:9) | ||
return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img), y | ||
end | ||
|
||
function get_cifar10_dataloaders(batchsize; kwargs...) | ||
cifar10_mean = (0.4914, 0.4822, 0.4465) | ||
cifar10_std = (0.2471, 0.2435, 0.2616) | ||
|
||
train_transform = RandomResizeCrop((32, 32)) |> | ||
Maybe(FlipX{2}()) |> | ||
ImageToTensor() |> | ||
Normalize(cifar10_mean, cifar10_std) | ||
|
||
test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std) | ||
|
||
trainset = TensorDataset(CIFAR10(:train), train_transform) | ||
trainloader = DataLoader(trainset; batchsize, shuffle=true, kwargs...) | ||
|
||
testset = TensorDataset(CIFAR10(:test), test_transform) | ||
testloader = DataLoader(testset; batchsize, shuffle=false, kwargs...) | ||
|
||
return trainloader, testloader | ||
end | ||
|
||
function accuracy(model, ps, st, dataloader) | ||
total_correct, total = 0, 0 | ||
cdev = cpu_device() | ||
for (x, y) in dataloader | ||
target_class = onecold(cdev(y)) | ||
predicted_class = onecold(cdev(first(model(x, ps, st)))) | ||
total_correct += sum(target_class .== predicted_class) | ||
total += length(target_class) | ||
end | ||
return total_correct / total | ||
end | ||
|
||
function get_accelerator_device(backend::String) | ||
if backend == "gpu_if_available" | ||
return gpu_device() | ||
elseif backend == "gpu" | ||
return gpu_device(; force=true) | ||
elseif backend == "reactant" | ||
return reactant_device(; force=true) | ||
elseif backend == "cpu" | ||
return cpu_device() | ||
else | ||
error("Invalid backend: $(backend). Valid Options are: `gpu_if_available`, `gpu`, \ | ||
`reactant`, and `cpu`.") | ||
end | ||
end | ||
|
||
function train_model( | ||
model, opt, scheduler=nothing; | ||
backend::String, batchsize::Int=512, seed::Int=1234, epochs::Int=25 | ||
) | ||
rng = Random.default_rng() | ||
Random.seed!(rng, seed) | ||
|
||
accelerator_device = get_accelerator_device(backend) | ||
kwargs = accelerator_device isa ReactantDevice ? (; partial=false) : () | ||
trainloader, testloader = get_cifar10_dataloaders(batchsize; kwargs...) |> | ||
accelerator_device | ||
|
||
ps, st = Lux.setup(rng, model) |> accelerator_device | ||
|
||
train_state = Training.TrainState(model, ps, st, opt) | ||
|
||
adtype = backend == "reactant" ? AutoEnzyme() : AutoZygote() | ||
|
||
if backend == "reactant" | ||
x_ra = rand(rng, Float32, size(first(trainloader)[1])) |> accelerator_device | ||
@printf "[Info] Compiling model with Reactant.jl\n" | ||
st_test = Lux.testmode(st) | ||
model_compiled = Reactant.compile(model, (x_ra, ps, st_test)) | ||
@printf "[Info] Model compiled!\n" | ||
else | ||
model_compiled = model | ||
end | ||
|
||
loss_fn = CrossEntropyLoss(; logits=Val(true)) | ||
|
||
pt = ProgressTable(; | ||
header=[ | ||
"Epoch", "Learning Rate", "Train Accuracy (%)", "Test Accuracy (%)", "Time (s)" | ||
], | ||
widths=[24, 24, 24, 24, 24], | ||
format=["%3d", "%.6f", "%.6f", "%.6f", "%.6f"], | ||
color=[:normal, :normal, :blue, :blue, :normal], | ||
border=true, | ||
alignment=[:center, :center, :center, :center, :center] | ||
) | ||
|
||
@printf "[Info] Training model\n" | ||
initialize(pt) | ||
|
||
for epoch in 1:epochs | ||
stime = time() | ||
lr = 0 | ||
for (i, (x, y)) in enumerate(trainloader) | ||
if scheduler !== nothing | ||
lr = scheduler((epoch - 1) + (i + 1) / length(trainloader)) | ||
train_state = Optimisers.adjust!(train_state, lr) | ||
end | ||
(_, loss, _, train_state) = Training.single_train_step!( | ||
adtype, loss_fn, (x, y), train_state | ||
) | ||
isnan(loss) && error("NaN loss encountered!") | ||
end | ||
ttime = time() - stime | ||
|
||
train_acc = accuracy( | ||
model_compiled, train_state.parameters, | ||
Lux.testmode(train_state.states), trainloader | ||
) * 100 | ||
test_acc = accuracy( | ||
model_compiled, train_state.parameters, | ||
Lux.testmode(train_state.states), testloader | ||
) * 100 | ||
|
||
scheduler === nothing && (lr = NaN32) | ||
next(pt, [epoch, lr, train_acc, test_acc, ttime]) | ||
end | ||
|
||
finalize(pt) | ||
@printf "[Info] Finished training\n" | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
using Comonicon, Interpolations, Lux, Optimisers, Printf, Random, Statistics, Zygote, Enzyme | ||
|
||
include("common.jl") | ||
|
||
function ConvMixer(; dim, depth, kernel_size=5, patch_size=2) | ||
#! format: off | ||
return Chain( | ||
Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size), | ||
BatchNorm(dim), | ||
[ | ||
Chain( | ||
SkipConnection( | ||
Chain( | ||
Conv( | ||
(kernel_size, kernel_size), dim => dim, gelu; | ||
groups=dim, pad=SamePad() | ||
), | ||
BatchNorm(dim) | ||
), | ||
+ | ||
), | ||
Conv((1, 1), dim => dim, gelu), | ||
BatchNorm(dim) | ||
) | ||
for _ in 1:depth | ||
]..., | ||
GlobalMeanPool(), | ||
FlattenLayer(), | ||
Dense(dim => 10) | ||
) | ||
#! format: on | ||
end | ||
|
||
Comonicon.@main function main(; | ||
batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, | ||
patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=0.0001, | ||
clip_norm::Bool=false, seed::Int=1234, epochs::Int=25, lr_max::Float64=0.05, | ||
backend::String="reactant" | ||
) | ||
model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size) | ||
|
||
opt = AdamW(; eta=lr_max, lambda=weight_decay) | ||
clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) | ||
|
||
lr_schedule = linear_interpolation( | ||
[0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0] | ||
) | ||
|
||
return train_model(model, opt, lr_schedule; backend, batchsize, seed, epochs) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
using Comonicon, Lux, Optimisers, Printf, Random, Statistics, Zygote, Enzyme | ||
|
||
include("common.jl") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
using Comonicon, Lux, Optimisers, Printf, Random, Statistics, Zygote, Enzyme | ||
|
||
include("common.jl") | ||
|
||
function SimpleCNN() | ||
return Chain( | ||
Conv((3, 3), 3 => 16, gelu; stride=2, pad=1), | ||
BatchNorm(16), | ||
Conv((3, 3), 16 => 32, gelu; stride=2, pad=1), | ||
BatchNorm(32), | ||
Conv((3, 3), 32 => 64, gelu; stride=2, pad=1), | ||
BatchNorm(64), | ||
Conv((3, 3), 64 => 128, gelu; stride=2, pad=1), | ||
BatchNorm(128), | ||
GlobalMeanPool(), | ||
FlattenLayer(), | ||
Dense(128 => 64, gelu), | ||
BatchNorm(64), | ||
Dense(64 => 10) | ||
) | ||
end | ||
|
||
Comonicon.@main function main(; | ||
batchsize::Int=512, weight_decay::Float64=0.0001, | ||
clip_norm::Bool=false, seed::Int=1234, epochs::Int=50, lr::Float64=0.003, | ||
backend::String="reactant" | ||
) | ||
model = SimpleCNN() | ||
|
||
opt = AdamW(; eta=lr, lambda=weight_decay) | ||
clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) | ||
|
||
return train_model(model, opt, nothing; backend, batchsize, seed, epochs) | ||
end |
Oops, something went wrong.