|
12 | 12 | @test transpose(X) * b2 == transpose(X) * Array(b2)
|
13 | 13 | @test transpose(v) * b2 == transpose(v) * Array(b2)
|
14 | 14 | @test_throws DimensionMismatch A*b2
|
| 15 | + |
| 16 | + # in-place with mul! |
| 17 | + c1 = fill(NaN, 3) |
| 18 | + @test mul!(c1, A, b1) == A * Array(b1) |
| 19 | + @test c1 == A * Array(b1) |
| 20 | + @test mul!(c1, transpose(A), b1) == transpose(A) * Array(b1) |
| 21 | + @test mul!(zeros(3,1), A, b1) == reshape(A * b1, 3,1) |
| 22 | + @test mul!([NaN], transpose(v), b2) == mul!([NaN], transpose(v), Array(b2)) |
| 23 | + |
| 24 | + @test_throws DimensionMismatch mul!(zeros(5), A, b1) |
| 25 | + @test_throws DimensionMismatch mul!(c1, X, b1) |
15 | 26 | end
|
16 | 27 |
|
17 | 28 | @testset "AbstractMatrix-OneHotMatrix multiplication" begin
|
|
22 | 33 | b2 = OneHotMatrix([2, 4, 1, 3], 5)
|
23 | 34 | b3 = OneHotMatrix([1, 1, 2], 4)
|
24 | 35 | b4 = reshape(OneHotMatrix([1 2 3; 2 2 1], 3), 3, :)
|
| 36 | + @test OneHotArrays._isonehot(b4) |
25 | 37 | b5 = reshape(b4, 6, :)
|
26 | 38 | b6 = reshape(OneHotMatrix([1 2 2; 2 2 1], 2), 3, :)
|
| 39 | + @test !OneHotArrays._isonehot(b6) |
27 | 40 | b7 = reshape(OneHotMatrix([1 2 3; 1 2 3], 3), 6, :)
|
28 | 41 |
|
29 | 42 | @test A * b1 == A[:,[1, 1, 2, 2]]
|
|
41 | 54 | @test_throws DimensionMismatch A*b2'
|
42 | 55 | @test_throws DimensionMismatch A*b6'
|
43 | 56 | @test_throws DimensionMismatch A*b7
|
| 57 | + |
| 58 | + # in-place with mul! |
| 59 | + c1 = fill(NaN, 3, 4) |
| 60 | + @test mul!(c1, A, b1) == A * b1 |
| 61 | + @test c1 == A * b1 |
| 62 | + |
| 63 | + c4 = fill(NaN, 3, 6) |
| 64 | + @test mul!(c4, A, b4) == A * b4 # b4 is reshaped but still one-hot |
| 65 | + @test mul!(c4, A', b4) == A' * b4 |
| 66 | + c6 = fill(NaN, 3, 4) |
| 67 | + @test mul!(c6, A, b6) == A * b6 # b4 is reshaped and not one-hot |
| 68 | + @test mul!(c6, A', b6) == A' * b6 |
| 69 | + |
| 70 | + @test_throws DimensionMismatch mul!(c1, A, b2) |
| 71 | + @test_throws DimensionMismatch mul!(c1, A, b4) |
| 72 | + @test_throws DimensionMismatch mul!(c4, A, b1) |
| 73 | + @test_throws DimensionMismatch mul!(zeros(10, 3), A, b1) |
44 | 74 | end
|
0 commit comments