New correct model, gpu functionality, overfit of a single image #12

Open · wants to merge 19 commits into master
16 changes: 0 additions & 16 deletions .github/workflows/CompatHelper.yml

This file was deleted.

11 changes: 0 additions & 11 deletions .github/workflows/TagBot.yml

This file was deleted.

40 changes: 0 additions & 40 deletions .github/workflows/ci.yml

This file was deleted.

2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
Manifest.toml
test_*.png
test/data/*
9 changes: 7 additions & 2 deletions Project.toml
@@ -4,12 +4,18 @@ authors = ["dhairyagandhi <[email protected]>"]
version = "0.1.0"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

@@ -19,7 +25,6 @@ FileIO = "1"
Flux = "0.10, 0.11"
ImageCore = "0.8"
ImageTransformations = "0.8"
Reexport = "0"
StatsBase = "0"
julia = "1.3"

89 changes: 13 additions & 76 deletions README.md
@@ -1,82 +1,19 @@
# UNet.jl

[![Actions Status](https://github.com/dhairyagandhi96/UNet.jl/workflows/CI/badge.svg)](https://github.com/dhairyagandhi96/UNet.jl/actions)
This package provides a generic UNet implemented in Julia using Flux. Originally based on https://github.com/DhairyaLGandhi/UNet.jl but heavily modified.

This package provides a generic UNet implemented in Julia.

The package is built on top of Flux.jl, and can therefore be extended as needed

```julia
julia> u = Unet()
UNet:
ConvDown(64, 64)
ConvDown(128, 128)
ConvDown(256, 256)
ConvDown(512, 512)


UNetConvBlock(1, 3)
UNetConvBlock(3, 64)
UNetConvBlock(64, 128)
UNetConvBlock(128, 256)
UNetConvBlock(256, 512)
UNetConvBlock(512, 1024)
UNetConvBlock(1024, 1024)


UNetUpBlock(1024, 512)
UNetUpBlock(1024, 256)
UNetUpBlock(512, 128)
UNetUpBlock(256, 64)
```

The default input channel dimension is `1`, i.e. grayscale. To support images with a different number of channels, pass the channel count to `Unet`.

```julia
julia> u = Unet(3) # for RGB images
```

The input spatial dimensions can be any power of two, e.g. an array of size `(256, 256, channels, batch_size)`.
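As a quick illustration (a hypothetical sketch; the channel count `3` and batch size `2` are arbitrary, and the output channel count depends on the model configuration, so it is not asserted here):

```julia
using UNet

u = Unet(3)                        # assumed: 3 input channels (RGB)
x = rand(Float32, 128, 128, 3, 2)  # 128×128 RGB images, batch of 2 (powers of two)
y = u(x)                           # forward pass; spatial size is preserved
size(y)[1:2] == (128, 128)
```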

## GPU Support

To run the model on a GPU, simply call `gpu` on the model.

```julia
julia> u = gpu(u);

julia> r = gpu(rand(Float32, 256, 256, 1, 1));

julia> size(u(r))
(256, 256, 1, 1)
```

## Training

Training UNet is a breeze too.

You can define your own loss function, or use a provided Binary Cross Entropy implementation via `bce`.

```julia
julia> w = rand(Float32, 256,256,1,1);

julia> w′ = rand(Float32, 256,256,1,1);

julia> function loss(x, y)
         op = clamp.(u(x), 0.001f0, 1f0)
         mean(bce(op, y))
       end
loss (generic function with 1 method)

julia> using Base.Iterators

julia> rep = Iterators.repeated((w, w′), 10);

julia> opt = Momentum()
Momentum(0.01, 0.9, IdDict{Any,Any}())

julia> Flux.train!(loss, Flux.params(u), rep, opt);
```

## Further Reading
The package is an implementation of the [paper](https://arxiv.org/pdf/1505.04597.pdf), and all credit for the model itself goes to the respective authors.

## Usage

See runtests.jl for an example of overfitting a single image, and train.jl for a generic training script.

* Input:
* ![GitHub Logo](/test/testdata/input.png)
* Target:
* ![GitHub Logo](/test/testdata/target.png)
* Prediction:
* ![GitHub Logo](/test/testdata/prediction.png)
* Training:
* ![GitHub Logo](/test/testdata/training.gif)
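The single-image overfitting workflow can be sketched roughly as follows (a minimal sketch mirroring the README's training snippet, not the package's actual test code; the random arrays stand in for a real image/target pair, which would normally be loaded with `load_img`):

```julia
using UNet, Flux, Statistics

u = Unet()                                  # default: 1 input channel (grayscale)
x = rand(Float32, 256, 256, 1, 1)           # stand-in for a real input image
y = rand(Float32, 256, 256, 1, 1)           # stand-in for its target

# Clamp the output before bce to avoid log(0), as in the loss above.
loss(x, y) = mean(bce(clamp.(u(x), 0.001f0, 1f0), y))

opt = Momentum()
data = Base.Iterators.repeated((x, y), 100) # the same pair, repeated -> overfit
Flux.train!(loss, Flux.params(u), data, opt)
```

With enough repetitions the loss on this single pair should drop steadily, which is a useful sanity check that the model and gradients are wired up correctly.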
21 changes: 13 additions & 8 deletions src/UNet.jl
@@ -1,22 +1,27 @@
module UNet

export Unet, bce, load_img, load_batch

using Reexport

using StatsBase
using Flux
using Flux: @functor
using Flux.Data: DataLoader
using Flux: logitcrossentropy, dice_coeff_loss

using Images
using ImageCore
using ImageTransformations: imresize
using FileIO
using Distributions: Normal

@reexport using Statistics
@reexport using Flux, Flux.Zygote, Flux.Optimise
using Serialization
using ForwardDiff
using Parameters: @with_kw
using CUDAapi
using CUDA

include("utils.jl")
include("defaults.jl")
include("dataloader.jl")
include("model.jl")
include("train.jl")

export Unet, train

end # module