
use MLUtils #1874


Merged: 9 commits, Feb 18, 2022
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -7,6 +7,8 @@ been removed in favour of MLDatasets.jl.
* `flatten` is no longer exported due to a clash with `Iterators.flatten`.
* Remove Juno.jl progress bar support as it is now obsolete.
* `Dropout` gained improved compatibility with Int and Complex arrays and is now twice-differentiable.
* Many utility functions and the `DataLoader` are [now provided by MLUtils.jl](https://github.com/FluxML/Flux.jl/pull/1874).
* The `DataLoader` is now compatible with generic dataset types implementing `MLUtils.numobs` and `MLUtils.getobs`.
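As a sketch of what the new observation interface enables (the `SquaresData` type is hypothetical, not part of the PR; assumes MLUtils.jl as added to Project.toml below):

```julia
using MLUtils

# Hypothetical dataset type: anything implementing numobs/getobs
# can be passed straight to DataLoader.
struct SquaresData
    n::Int
end

MLUtils.numobs(d::SquaresData) = d.n
MLUtils.getobs(d::SquaresData, i) = i .^ 2  # accepts an Int or a collection of indices

loader = DataLoader(SquaresData(10); batchsize=3, shuffle=false)
first(loader)  # batch of the first three observations: [1, 4, 9]
```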

## v0.12.10
* `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)
2 changes: 2 additions & 0 deletions Project.toml
@@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
@@ -25,6 +26,7 @@ Adapt = "3.0"
ArrayInterface = "3.1, 4"
CUDA = "3"
Functors = "0.2.1"
MLUtils = "0.1.4"
MacroTools = "0.5"
NNlib = "0.8.2"
NNlibCUDA = "0.2"
5 changes: 4 additions & 1 deletion docs/src/models/nnlib.md
@@ -76,7 +76,10 @@ NNlib.gather
NNlib.gather!
NNlib.scatter
NNlib.scatter!
```

## Miscellaneous

## Utilities
```@docs
NNlib.logsumexp
```
18 changes: 10 additions & 8 deletions docs/src/utilities.md
@@ -7,15 +7,17 @@ callback functions.

## Working with Data

Utilities for data processing are provided by [MLUtils.jl](https://github.com/JuliaML/MLUtils.jl). Below is a non-exhaustive list.

```@docs
Flux.unsqueeze
Flux.stack
Flux.unstack
Flux.chunk
Flux.frequencies
Flux.batch
Flux.unbatch
Flux.batchseq
MLUtils.unsqueeze
MLUtils.stack
MLUtils.unstack
MLUtils.chunk
MLUtils.group_counts
MLUtils.batch
MLUtils.unbatch
MLUtils.batchseq
Base.rpad(v::AbstractVector, n::Integer, p)
```
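For readers coming from the old `Flux.*` names, a small usage sketch of the MLUtils replacements listed above (assumes only MLUtils.jl):

```julia
using MLUtils

# batch stacks observations along a new trailing dimension; unbatch inverts it.
xs = [[1, 2], [3, 4], [5, 6]]
X = batch(xs)       # 2×3 Matrix: each vector becomes a column
unbatch(X) == xs    # true

# unsqueeze inserts a singleton dimension (note the dims keyword).
unsqueeze([1, 2, 3]; dims=1)  # 1×3 Matrix
```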

4 changes: 4 additions & 0 deletions src/Flux.jl
@@ -7,6 +7,9 @@ using Statistics, Random, LinearAlgebra
using Zygote, MacroTools, ProgressLogging, Reexport
using MacroTools: @forward
@reexport using NNlib

using MLUtils

using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient

@@ -50,6 +53,7 @@ include("outputsize.jl")
include("data/Data.jl")
using .Data


include("losses/Losses.jl")
using .Losses # TODO: stop importing Losses in Flux's namespace in v0.12

5 changes: 1 addition & 4 deletions src/data/Data.jl
@@ -1,9 +1,6 @@
module Data

using Random: shuffle!
using Base: @propagate_inbounds

include("dataloader.jl")
using MLUtils
export DataLoader

end#module
121 changes: 0 additions & 121 deletions src/data/dataloader.jl

This file was deleted.

35 changes: 0 additions & 35 deletions src/data/tree.jl

This file was deleted.

1 change: 1 addition & 0 deletions src/deprecations.jl
@@ -16,3 +16,4 @@ ones32(::Type, dims...) = throw(ArgumentError("Flux.ones32 is always Float32, use Base.ones to specify the element type"))
zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32, use Base.zeros to specify the element type"))

# v0.13 deprecations
@deprecate frequencies(xs) group_counts(xs)
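The replacement behaves like the old `frequencies`; a minimal sketch assuming MLUtils.jl:

```julia
using MLUtils

# group_counts returns a Dict mapping each distinct value to its count.
counts = group_counts(["a", "b", "a", "c", "a"])
counts["a"]  # 3
```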
2 changes: 1 addition & 1 deletion src/onehot.jl
@@ -87,7 +87,7 @@ Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 2}} =
Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 1}} =
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L)

batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(_indices.(xs), L)
MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatrix(_indices.(xs), L)

Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)

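To see the changed `batch` method in action, a sketch assuming Flux with this PR applied (the character labels are arbitrary):

```julia
using Flux
using MLUtils

# Batching OneHotVectors now dispatches to MLUtils.batch and keeps the
# compact one-hot representation instead of materialising a dense array.
vs = [Flux.onehot(c, 'a':'c') for c in "cab"]
M = batch(vs)               # 3×3 one-hot matrix
Flux.onecold(M, 'a':'c')    # recovers ['c', 'a', 'b']
```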