
Commit 19d4275

Merge pull request #276 from FluxML/double-trouble
Force `Float32` as type presented to Flux chains
2 parents: 945016d + bc805e4

12 files changed: +84 −68 lines
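
In practice, this change means a user can pass `Float64` tables (Julia's default) to an MLJFlux model and the data is converted to `Float32` (the eltype Flux favours) before it reaches the chain, with an `@info` message controlled by the usual verbosity knob. A hypothetical end-to-end sketch, assuming MLJ and MLJFlux are installed (model choice and data are illustrative only):

```julia
using MLJ  # assumed installed; @load fetches the MLJFlux model type

NNR = @load NeuralNetworkRegressor pkg=MLJFlux
X = MLJ.table(rand(100, 3))   # Float64 features, Julia's default
y = rand(100)

mach = machine(NNR(), X, y)
fit!(mach)                           # logs "MLJFlux: converting input data to Float32"
fit!(mach, force=true, verbosity=0)  # same fit, conversion message suppressed
```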

Project.toml (+2 −1)

```diff
@@ -33,6 +33,7 @@ julia = "1.9"
 [extras]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -42,4 +43,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
 [targets]
-test = ["CUDA", "cuDNN", "LinearAlgebra", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "Test"]
+test = ["CUDA", "cuDNN", "LinearAlgebra", "Logging", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "Test"]
```

src/core.jl (+11 −4)

```diff
@@ -276,15 +276,22 @@ input `X` and target `y` in the form required by
 by `model.batch_size`.)
 
 """
-function collate(model, X, y)
+function collate(model, X, y, verbosity)
     row_batches = Base.Iterators.partition(1:nrows(y), model.batch_size)
-    Xmatrix = reformat(X)
+    Xmatrix = _f32(reformat(X), verbosity)
     ymatrix = reformat(y)
     return [_get(Xmatrix, b) for b in row_batches], [_get(ymatrix, b) for b in row_batches]
 end
-function collate(model::NeuralNetworkBinaryClassifier, X, y)
+function collate(model::NeuralNetworkBinaryClassifier, X, y, verbosity)
     row_batches = Base.Iterators.partition(1:nrows(y), model.batch_size)
-    Xmatrix = reformat(X)
+    Xmatrix = _f32(reformat(X), verbosity)
     yvec = (y .== classes(y)[2])' # convert to boolean
     return [_get(Xmatrix, b) for b in row_batches], [_get(yvec, b) for b in row_batches]
 end
+
+_f32(x::AbstractArray{Float32}, verbosity) = x
+function _f32(x::AbstractArray, verbosity)
+    verbosity > 0 && @info "MLJFlux: converting input data to Float32"
+    return Float32.(x)
+end
+
```

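The dispatch pattern added above is small enough to demo standalone. A minimal sketch (independent of MLJFlux; the helper definition is copied from the diff) showing the no-copy fast path for `Float32` input and the logged conversion otherwise:

```julia
# Standalone sketch of the `_f32` helper introduced above (no MLJFlux required):
_f32(x::AbstractArray{Float32}, verbosity) = x   # already Float32: returned as-is
function _f32(x::AbstractArray, verbosity)
    verbosity > 0 && @info "MLJFlux: converting input data to Float32"
    return Float32.(x)                           # e.g. Float64 -> Float32 copy
end

x64 = rand(3, 2)                   # Float64 matrix
x32 = _f32(x64, 1)                 # logs the @info message, returns a Float32 copy
@assert eltype(x32) == Float32
@assert _f32(x32, 1) === x32       # Float32 input passes through without copying
```
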
src/encoders.jl (+4 −4)

```diff
@@ -19,11 +19,11 @@ function ordinal_encoder_fit(X; featinds)
         feat_col = Tables.getcolumn(Tables.columns(X), i)
         feat_levels = levels(feat_col)
         # Check if feat levels is already ordinal encoded in which case we skip
-        (Set([float(i) for i in 1:length(feat_levels)]) == Set(feat_levels)) && continue
+        (Set([Float32(i) for i in 1:length(feat_levels)]) == Set(feat_levels)) && continue
         # Compute the dict using the given feature_mapper function
         mapping_matrix[i] =
-            Dict{Any, AbstractFloat}(
-                value => float(index) for (index, value) in enumerate(feat_levels)
+            Dict{eltype(feat_levels), Float32}(
+                value => Float32(index) for (index, value) in enumerate(feat_levels)
             )
     end
     return mapping_matrix
@@ -67,7 +67,7 @@ function ordinal_encoder_transform(X, mapping_matrix)
             test_levels = levels(col)
             check_unkown_levels(train_levels, test_levels)
             level2scalar = mapping_matrix[ind]
-            new_col = recode(col, level2scalar...)
+            new_col = recode(unwrap.(col), level2scalar...)
             push!(new_feats, new_col)
         else
             push!(new_feats, col)
```
src/entity_embedding.jl (+1 −2)

````diff
@@ -36,15 +36,14 @@ julia> output = embedder(batch)
 ```
 """ # 1. Define layer struct to hold parameters
 struct EntityEmbedder{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer}
-
     embedders::A1
     modifiers::A2 # applied on the input before passing it to the embedder
     numfeats::I
 end
 
 # 2. Define the forward pass (i.e., calling an instance of the layer)
 (m::EntityEmbedder)(x) =
-    vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...)
+    (vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...))
 
 # 3. Define the constructor which initializes the parameters and returns the instance
 function EntityEmbedder(entityprops, numfeats; init = Flux.randn32)
````

src/mlj_model_interface.jl (+12 −12)

```diff
@@ -66,35 +66,35 @@ function MLJModelInterface.fit(model::MLJFluxModel,
     X,
     y)
     # GPU and rng related variables
-    move = Mover(model.acceleration)
+    move = MLJFlux.Mover(model.acceleration)
     rng = true_rng(model)
 
     # Get input properties
     shape = MLJFlux.shape(model, X, y)
-    cat_inds = get_cat_inds(X)
+    cat_inds = MLJFlux.get_cat_inds(X)
     pure_continuous_input = isempty(cat_inds)
 
     # Decide whether to enable entity embeddings (e.g., ImageClassifier won't)
-    enable_entity_embs = is_embedding_enabled(model) && !pure_continuous_input
+    enable_entity_embs = MLJFlux.is_embedding_enabled(model) && !pure_continuous_input
 
     # Prepare entity embeddings inputs and encode X if entity embeddings enabled
    featnames = []
     if enable_entity_embs
-        X = convert_to_table(X)
+        X = MLJFlux.convert_to_table(X)
         featnames = Tables.schema(X).names
     end
 
-    # entityprops is (index = cat_inds[i], levels = num_levels[i], newdim = newdims[i])
+    # entityprops is (index = cat_inds[i], levels = num_levels[i], newdim = newdims[i])
     # for each categorical feature
     default_embedding_dims = enable_entity_embs ? model.embedding_dims : Dict{Symbol, Real}()
     entityprops, entityemb_output_dim =
-        prepare_entityembs(X, featnames, cat_inds, default_embedding_dims)
-    X, ordinal_mappings = ordinal_encoder_fit_transform(X; featinds = cat_inds)
+        MLJFlux.prepare_entityembs(X, featnames, cat_inds, default_embedding_dims)
+    X, ordinal_mappings = MLJFlux.ordinal_encoder_fit_transform(X; featinds = cat_inds)
 
     ## Construct model chain
     chain =
         (!enable_entity_embs) ? construct_model_chain(model, rng, shape, move) :
-        construct_model_chain_with_entityembs(
+        MLJFlux.construct_model_chain_with_entityembs(
             model,
             rng,
             shape,
@@ -103,8 +103,8 @@ function MLJModelInterface.fit(model::MLJFluxModel,
             entityemb_output_dim,
         )
 
-    # Format data as needed by Flux and move to GPU
-    data = move.(collate(model, X, y))
+    # Format data as needed by Flux and move to GPU
+    data = move.(MLJFlux.collate(model, X, y, verbosity))
 
     # Test chain works (as it may be custom)
     x = data[1][1]
@@ -143,7 +143,7 @@ function MLJModelInterface.fit(model::MLJFluxModel,
         featnames,
     )
 
-    # Prepare fitresult
+    # Prepare fitresult
     fitresult =
         MLJFlux.fitresult(model, Flux.cpu(chain), y, ordinal_mappings, embedding_matrices)
 
@@ -216,7 +216,7 @@ function MLJModelInterface.update(model::MLJFluxModel,
         chain = construct_model_chain(model, rng, shape, move)
     end
     # reset `optimiser_state`:
-    data = move.(collate(model, X, y))
+    data = move.(collate(model, X, y, verbosity))
     regularized_optimiser, optimiser_state =
         prepare_optimiser(data, model, chain)
     epochs = model.epochs
```

test/classifier.jl (+7 −7)

```diff
@@ -4,18 +4,17 @@ seed!(1234)
 N = 300
 Xm = MLJBase.table(randn(Float32, N, 5)); # purely numeric
 X = (; Tables.columntable(Xm)...,
-    Column1 = repeat([1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
+    Column1 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
     Column2 = categorical(repeat(['a', 'b', 'c', 'd', 'e'], Int(N / 5))),
     Column3 = categorical(repeat(["b", "c", "d", "f", "f"], Int(N / 5)), ordered = true),
-    Column4 = repeat([1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
-    Column5 = randn(N),
+    Column4 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
+    Column5 = randn(Float32, N),
     Column6 = categorical(
         repeat(["group1", "group1", "group2", "group2", "group3"], Int(N / 5)),
     ),
 )
 
-
-ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N)
+ycont = 2 * X.x1 - X.x3 + 0.1 * rand(Float32, N)
 m, M = minimum(ycont), maximum(ycont)
 _, a, b, _ = range(m, stop = M, length = 4) |> collect
 y = map(ycont) do η
@@ -111,7 +110,8 @@ end
 
 # check different resources (CPU1, CUDALibs, etc)) give about the same loss:
 reference = losses[1]
-@test all(x -> abs(x - reference) / reference < 1e-4, losses[2:end])
+println("losses for each resource: $losses")
+@test all(x -> abs(x - reference) / reference < 0.03, losses[2:end])
 
 
 # # NEURAL NETWORK BINARY CLASSIFIER
@@ -126,7 +126,7 @@ end
 seed!(1234)
 N = 300
 X = MLJBase.table(rand(Float32, N, 4));
-ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N)
+ycont = Float32.(2 * X.x1 - X.x3 + 0.1 * rand(N))
 m, M = minimum(ycont), maximum(ycont)
 _, a, _ = range(m, stop = M, length = 3) |> collect
 y = map(ycont) do η
```

test/core.jl (+12 −7)

```diff
@@ -14,23 +14,28 @@ rowvec(y::Vector) = reshape(y, 1, length(y))
 end
 
 @testset "collate" begin
-    # NeuralNetworRegressor:
-    Xmatrix = broadcast(x->round(x, sigdigits=2), rand(stable_rng, 10, 3))
+    Xmatrix = broadcast(x->round(x, sigdigits=2), rand(stable_rng, Float32, 10, 3))
+    Xmat_f64 = Float64.(Xmatrix)
     # convert to a column table:
     X = MLJBase.table(Xmatrix)
+    X_64 = MLJBase.table(Xmat_f64)
 
+    # NeuralNetworRegressor:
     y = rand(stable_rng, Float32, 10)
     model = MLJFlux.NeuralNetworkRegressor()
     model.batch_size= 3
-    @test MLJFlux.collate(model, X, y) ==
+    @test MLJFlux.collate(model, X, y, 1) == MLJFlux.collate(model, X_64, y, 1) ==
         ([Xmatrix'[:,1:3], Xmatrix'[:,4:6], Xmatrix'[:,7:9], Xmatrix'[:,10:10]],
          rowvec.([y[1:3], y[4:6], y[7:9], y[10:10]]))
+    @test_logs (:info,) MLJFlux.collate(model, X_64, y, 1)
+    @test_logs min_level=Logging.Info MLJFlux.collate(model, X, y, 1)
+    @test_logs min_level=Logging.Info MLJFlux.collate(model, X, y, 0)
 
     # NeuralNetworClassifier:
     y = categorical(['a', 'b', 'a', 'a', 'b', 'a', 'a', 'a', 'b', 'a'])
     model = MLJFlux.NeuralNetworkClassifier()
     model.batch_size = 3
-    data = MLJFlux.collate(model, X, y)
+    data = MLJFlux.collate(model, X, y, 1)
 
     @test data == ([Xmatrix'[:,1:3], Xmatrix'[:,4:6],
                     Xmatrix'[:,7:9], Xmatrix'[:,10:10]],
@@ -42,13 +47,13 @@ end
     y = MLJBase.table(ymatrix) # a rowaccess table
     model = MLJFlux.NeuralNetworkRegressor()
     model.batch_size= 3
-    @test MLJFlux.collate(model, X, y) ==
+    @test MLJFlux.collate(model, X, y, 1) ==
         ([Xmatrix'[:,1:3], Xmatrix'[:,4:6], Xmatrix'[:,7:9], Xmatrix'[:,10:10]],
         rowvec.([ymatrix'[:,1:3], ymatrix'[:,4:6], ymatrix'[:,7:9],
                  ymatrix'[:,10:10]]))
 
     y = Tables.columntable(y) # try a columnaccess table
-    @test MLJFlux.collate(model, X, y) ==
+    @test MLJFlux.collate(model, X, y, 1) ==
         ([Xmatrix'[:,1:3], Xmatrix'[:,4:6], Xmatrix'[:,7:9], Xmatrix'[:,10:10]],
         rowvec.([ymatrix'[:,1:3], ymatrix'[:,4:6],
                  ymatrix'[:,7:9], ymatrix'[:,10:10]]))
@@ -58,7 +63,7 @@ end
     y = categorical(['a', 'b', 'a', 'a', 'b', 'a', 'a', 'a', 'b', 'a'])
     model = MLJFlux.ImageClassifier(batch_size=2)
 
-    data = MLJFlux.collate(model, Xmatrix, y)
+    data = MLJFlux.collate(model, Xmatrix, y, 1)
     @test first.(data) == (Float32.(cat(Xmatrix[1], Xmatrix[2], dims=4)),
                            rowvec.([1 0;0 1]))
 
```

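For readers unfamiliar with the log-testing macros used above (from Julia's `Test` and `Logging` standard libraries): `@test_logs (:info,) expr` asserts that `expr` emits exactly one `@info` record, while `@test_logs min_level=Logging.Info expr` asserts it emits nothing at `Info` level or above. A standalone sketch:

```julia
using Test, Logging

noisy() = (@info "converting"; true)
quiet() = true

@test_logs (:info,) noisy()                # passes: exactly one Info record emitted
@test_logs min_level=Logging.Info quiet()  # passes: nothing logged at Info or above
```
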
test/encoders.jl (+2 −2)

```diff
@@ -12,8 +12,8 @@
     @test map[2] == Dict('a' => 1, 'b' => 2, 'c' => 3, 'd' => 4, 'e' => 5)
     @test map[3] == Dict("b" => 1, "c" => 2, "d" => 3)
     @test Xenc.Column1 == [1.0, 2.0, 3.0, 4.0, 5.0]
-    @test Xenc.Column2 == [1.0, 2.0, 3.0, 4.0, 5.0]
-    @test Xenc.Column3 == [1, 2, 3]
+    @test Xenc.Column2 == Float32.([1.0, 2.0, 3.0, 4.0, 5.0])
+    @test Xenc.Column3 == Float32.([1, 2, 3])
     @test Xenc.Column4 == [1.0, 2.0, 3.0, 4.0, 5.0]
 
     X = coerce(X, :Column1 => Multiclass)
```

test/entity_embedding.jl (+8 −7)

```diff
@@ -1,13 +1,13 @@
 """
 See more functional tests in entity_embedding_utils.jl and mlj_model_interface.jl
 """
-
-batch = [
+batch = Float32.([
     0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1;
-    1 2 3 4 5 6 7 8 9 10;
-    0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1
-    1 1 2 2 1 1 2 2 1 1
-]
+    1 2 3 4 5 6 7 8 9 10;
+    0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1;
+    1 1 2 2 1 1 2 2 1 1
+])
+
 
 entityprops = [
     (index = 2, levels = 10, newdim = 2),
@@ -145,7 +145,8 @@ end
     numfeats = 4
     embedder = MLJFlux.EntityEmbedder(entityprops, 4)
     output = embedder(batch)
-    @test output == batch
+    @test output ≈ batch
+    @test eltype(output) == Float32
 end
 
 
```
