Skip to content

Commit 13a65be

Browse files
Merge pull request #1874 from FluxML/cl/mlutils
use MLUtils
2 parents 1f3915d + bdbcaaa commit 13a65be

File tree

12 files changed

+46
-420
lines changed

12 files changed

+46
-420
lines changed

Diff for: NEWS.md

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ been removed in favour of MLDatasets.jl.
77
* `flatten` is not exported anymore due to clash with Iterators.flatten.
88
* Remove Juno.jl progress bar support as it is now obsolete.
99
* `Dropout` gained improved compatibility with Int and Complex arrays and is now twice-differentiable.
10+
* Many utily functions and the `DataLoader` are [now provided by MLUtils.jl](https://github.com/FluxML/Flux.jl/pull/1874).
11+
* The DataLoader is now compatible with generic dataset types implementing `MLUtils.numobs` and `MLUtils.getobs`.
1012

1113
## v0.12.10
1214
* `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)

Diff for: Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
99
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1112
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1213
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1314
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
@@ -25,6 +26,7 @@ Adapt = "3.0"
2526
ArrayInterface = "3.1, 4"
2627
CUDA = "3"
2728
Functors = "0.2.1"
29+
MLUtils = "0.1.4"
2830
MacroTools = "0.5"
2931
NNlib = "0.8.2"
3032
NNlibCUDA = "0.2"

Diff for: docs/src/models/nnlib.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,10 @@ NNlib.gather
7676
NNlib.gather!
7777
NNlib.scatter
7878
NNlib.scatter!
79+
```
80+
81+
## Miscellaneous
7982

80-
## Utilities
83+
```@docs
8184
NNlib.logsumexp
8285
```

Diff for: docs/src/utilities.md

+10-8
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,17 @@ callback functions.
77

88
## Working with Data
99

10+
Utilities for data processing are provided by [MLUtils.jl](https://github.com/JuliaML/MLUtils.jl). Below is a non-exhaustive list.
11+
1012
```@docs
11-
Flux.unsqueeze
12-
Flux.stack
13-
Flux.unstack
14-
Flux.chunk
15-
Flux.frequencies
16-
Flux.batch
17-
Flux.unbatch
18-
Flux.batchseq
13+
MLUtils.unsqueeze
14+
MLUtils.stack
15+
MLUtils.unstack
16+
MLUtils.chunk
17+
MLUtils.group_counts
18+
MLUtils.batch
19+
MLUtils.unbatch
20+
MLUtils.batchseq
1921
Base.rpad(v::AbstractVector, n::Integer, p)
2022
```
2123

Diff for: src/Flux.jl

+4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ using Statistics, Random, LinearAlgebra
77
using Zygote, MacroTools, ProgressLogging, Reexport
88
using MacroTools: @forward
99
@reexport using NNlib
10+
11+
using MLUtils
12+
1013
using Zygote: Params, @adjoint, gradient, pullback, @nograd
1114
export gradient
1215

@@ -50,6 +53,7 @@ include("outputsize.jl")
5053
include("data/Data.jl")
5154
using .Data
5255

56+
5357
include("losses/Losses.jl")
5458
using .Losses # TODO: stop importing Losses in Flux's namespace in v0.12
5559

Diff for: src/data/Data.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
module Data
22

3-
using Random: shuffle!
4-
using Base: @propagate_inbounds
5-
6-
include("dataloader.jl")
3+
using MLUtils
74
export DataLoader
85

96
end#module

Diff for: src/data/dataloader.jl

-121
This file was deleted.

Diff for: src/data/tree.jl

-35
This file was deleted.

Diff for: src/deprecations.jl

+1
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ ones32(::Type, dims...) = throw(ArgumentError("Flux.ones32 is always Float32, us
1616
zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32, use Base.zeros to specify the element type"))
1717

1818
# v0.13 deprecations
19+
@deprecate frequencies(xs) group_counts(xs)

Diff for: src/onehot.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 2}} =
8787
Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 1}} =
8888
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L)
8989

90-
batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(_indices.(xs), L)
90+
MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatrix(_indices.(xs), L)
9191

9292
Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)
9393

0 commit comments

Comments
 (0)