@@ -14,23 +14,28 @@ rowvec(y::Vector) = reshape(y, 1, length(y))
14
14
end
15
15
16
16
@testset " collate" begin
17
- # NeuralNetworRegressor:
18
- Xmatrix = broadcast (x -> round (x, sigdigits = 2 ), rand (stable_rng, 10 , 3 ) )
17
+ Xmatrix = broadcast (x -> round (x, sigdigits = 2 ), rand (stable_rng, Float32, 10 , 3 ))
18
+ Xmat_f64 = Float64 .(Xmatrix )
19
19
# convert to a column table:
20
20
X = MLJBase. table (Xmatrix)
21
+ X_64 = MLJBase. table (Xmat_f64)
21
22
23
+ # NeuralNetworRegressor:
22
24
y = rand (stable_rng, Float32, 10 )
23
25
model = MLJFlux. NeuralNetworkRegressor ()
24
26
model. batch_size= 3
25
- @test MLJFlux. collate (model, X, y) ==
27
+ @test MLJFlux. collate (model, X, y, 1 ) == MLJFlux . collate (model, X_64, y, 1 ) ==
26
28
([Xmatrix' [:,1 : 3 ], Xmatrix' [:,4 : 6 ], Xmatrix' [:,7 : 9 ], Xmatrix' [:,10 : 10 ]],
27
29
rowvec .([y[1 : 3 ], y[4 : 6 ], y[7 : 9 ], y[10 : 10 ]]))
30
+ @test_logs (:info ,) MLJFlux. collate (model, X_64, y, 1 )
31
+ @test_logs min_level= Logging. Info MLJFlux. collate (model, X, y, 1 )
32
+ @test_logs min_level= Logging. Info MLJFlux. collate (model, X, y, 0 )
28
33
29
34
# NeuralNetworClassifier:
30
35
y = categorical ([' a' , ' b' , ' a' , ' a' , ' b' , ' a' , ' a' , ' a' , ' b' , ' a' ])
31
36
model = MLJFlux. NeuralNetworkClassifier ()
32
37
model. batch_size = 3
33
- data = MLJFlux. collate (model, X, y)
38
+ data = MLJFlux. collate (model, X, y, 1 )
34
39
35
40
@test data == ([Xmatrix' [:,1 : 3 ], Xmatrix' [:,4 : 6 ],
36
41
Xmatrix' [:,7 : 9 ], Xmatrix' [:,10 : 10 ]],
42
47
y = MLJBase. table (ymatrix) # a rowaccess table
43
48
model = MLJFlux. NeuralNetworkRegressor ()
44
49
model. batch_size= 3
45
- @test MLJFlux. collate (model, X, y) ==
50
+ @test MLJFlux. collate (model, X, y, 1 ) ==
46
51
([Xmatrix' [:,1 : 3 ], Xmatrix' [:,4 : 6 ], Xmatrix' [:,7 : 9 ], Xmatrix' [:,10 : 10 ]],
47
52
rowvec .([ymatrix' [:,1 : 3 ], ymatrix' [:,4 : 6 ], ymatrix' [:,7 : 9 ],
48
53
ymatrix' [:,10 : 10 ]]))
49
54
50
55
y = Tables. columntable (y) # try a columnaccess table
51
- @test MLJFlux. collate (model, X, y) ==
56
+ @test MLJFlux. collate (model, X, y, 1 ) ==
52
57
([Xmatrix' [:,1 : 3 ], Xmatrix' [:,4 : 6 ], Xmatrix' [:,7 : 9 ], Xmatrix' [:,10 : 10 ]],
53
58
rowvec .([ymatrix' [:,1 : 3 ], ymatrix' [:,4 : 6 ],
54
59
ymatrix' [:,7 : 9 ], ymatrix' [:,10 : 10 ]]))
58
63
y = categorical ([' a' , ' b' , ' a' , ' a' , ' b' , ' a' , ' a' , ' a' , ' b' , ' a' ])
59
64
model = MLJFlux. ImageClassifier (batch_size= 2 )
60
65
61
- data = MLJFlux. collate (model, Xmatrix, y)
66
+ data = MLJFlux. collate (model, Xmatrix, y, 1 )
62
67
@test first .(data) == (Float32 .(cat (Xmatrix[1 ], Xmatrix[2 ], dims= 4 )),
63
68
rowvec .([1 0 ;0 1 ]))
64
69
0 commit comments