Skip to content

Commit 0b49d1b

Browse files
add missing mul! implementation (#47)
* add missing mul! implementation * remove broken from working test * bump version * add broken test * remove test_broken note that is no longer relevant * LTS is now 1.10 * bump buildkite julia version * clean up tests * check dimensions of output matrix in mul! * Update linalg.jl * Update linalg.jl * Update src/linalg.jl Co-authored-by: Michael Abbott <[email protected]> --------- Co-authored-by: Michael Abbott <[email protected]>
1 parent 82c8ed4 commit 0b49d1b

File tree

7 files changed

+76
-15
lines changed

7 files changed

+76
-15
lines changed

.buildkite/pipeline.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
steps:
2-
- label: "GPU integration with julia v1.6"
2+
- label: "GPU integration with julia v1.10"
33
plugins:
44
- JuliaCI/julia#v1:
55
# Drop default "registries" directory, so it is not persisted from execution to execution
66
# Taken from https://github.com/JuliaLang/julia/blob/v1.7.2/.buildkite/pipelines/main/platforms/package_linux.yml#L11-L12
77
persist_depot_dirs: packages,artifacts,compiled
8-
version: "1.6"
8+
version: "1.10"
99
- JuliaCI/julia-test#v1: ~
1010
agents:
1111
queue: "juliagpu"

.github/workflows/CI.yml

+6-6
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
fail-fast: false
1919
matrix:
2020
version:
21-
- '1.6'
21+
- '1.10'
2222
- '1'
2323
- 'nightly'
2424
os:
@@ -47,17 +47,17 @@ jobs:
4747

4848
- name: "Run test without coverage report"
4949
uses: julia-actions/julia-runtest@v1
50-
if: ${{ !contains(fromJson('["1", "1.6"]'), matrix.version) || matrix.os != 'ubuntu-latest' }}
50+
if: ${{ !contains(fromJson('["1", "1.10"]'), matrix.version) || matrix.os != 'ubuntu-latest' }}
5151
with:
5252
coverage: false
5353

5454
- name: "Run test with coverage report"
5555
uses: julia-actions/julia-runtest@v1
56-
if: contains(fromJson('["1", "1.6"]'), matrix.version) && matrix.os == 'ubuntu-latest'
56+
if: contains(fromJson('["1", "1.10"]'), matrix.version) && matrix.os == 'ubuntu-latest'
5757
- uses: julia-actions/julia-processcoverage@v1
58-
if: contains(fromJson('["1", "1.6"]'), matrix.version) && matrix.os == 'ubuntu-latest'
58+
if: contains(fromJson('["1", "1.10"]'), matrix.version) && matrix.os == 'ubuntu-latest'
5959
- uses: codecov/codecov-action@v3
60-
if: contains(fromJson('["1", "1.6"]'), matrix.version) && matrix.os == 'ubuntu-latest'
60+
if: contains(fromJson('["1", "1.10"]'), matrix.version) && matrix.os == 'ubuntu-latest'
6161
with:
6262
files: lcov.info
6363

@@ -68,7 +68,7 @@ jobs:
6868
- uses: actions/checkout@v3
6969
- uses: julia-actions/setup-julia@v1
7070
with:
71-
version: '1.6'
71+
version: '1.10'
7272
- run: |
7373
julia --project=docs -e '
7474
using Pkg

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "OneHotArrays"
22
uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
3-
version = "0.2.6"
3+
version = "0.2.7"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/linalg.jl

+17
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,20 @@ for wrapper in [:Adjoint, :Transpose]
3333
end
3434
end
3535
end
36+
37+
function LinearAlgebra.mul!(Y::AbstractVecOrMat, A::AbstractMatrix, B::OneHotLike)
38+
_isonehot(B) || return invoke(mul!, Tuple{AbstractArray,AbstractMatrix,AbstractMatrix}, Y, A, B)
39+
if size(A,2) size(B,1)
40+
throw(DimensionMismatch("Matrix column must correspond with the OneHot Size $(size(A,2))$(size(B,1))"))
41+
end
42+
if !(size(Y,1) == size(A,1) && size(Y,2) == size(B,2))
43+
throw(DimensionMismatch("Invalid output matrix size for multiplication of matrix sizes $(size(A)) and $(size(B))"))
44+
end
45+
idxs = _indices(B)
46+
if idxs isa Integer # occurs whe B is AbstractVector
47+
copyto!(Y, view(A, :, idxs))
48+
else
49+
NNlib.gather!(Y, A, idxs)
50+
end
51+
end
52+

test/gpu.jl

+19-5
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,23 @@ end
2323
@test (repr("text/plain", y); true)
2424

2525
gA = rand(3, 2) |> cu;
26-
if VERSION >= v"1.9" && CUDA.functional()
27-
@test gradient(A -> sum(A * y), gA)[1] isa CuArray
28-
else
29-
@test_broken gradient(A -> sum(A * y), gA)[1] isa CuArray # fails with JLArray, bug in Zygote?
30-
end
26+
27+
#NOTE: this would require something that can copute gradient... we don't have that here?
28+
#@test gradient(A -> sum(A * y), gA)[1] isa CuArray
29+
30+
# some specialized implementations call only mul! and not *, so we must ensure this works
31+
@test LinearAlgebra.mul!(similar(gA, 3, 3), gA, y) gA*y
32+
@test LinearAlgebra.mul!(similar(gA, 3, 1), gA, onehot(1, 1:2)) gA*onehot(1, 1:2)
33+
34+
@test_throws DimensionMismatch LinearAlgebra.mul!(similar(gA, 3, 4), gA, y)
35+
36+
gB = rand(3, 3) |> cu
37+
@test_throws DimensionMismatch LinearAlgebra.mul!(similar(gB, 3, 3), gB, y)
38+
39+
#TODO: the below fails due to method ambiguity and GPU scalar indexing
40+
y = reshape(y, 3, 2)
41+
gA = rand(2, 3) |> cu
42+
@test_broken LinearAlgebra.mul!(similar(gA, 2, 2), gA, y) gA*y
3143
end
3244

3345
@testset "onehotbatch(::CuArray, ::UnitRange)" begin
@@ -48,7 +60,9 @@ end
4860
y = onehotbatch(ones(3), 1:10) |> cu;
4961
l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
5062
@test onecold(y) isa CuArray
63+
@test onecold(y) == cu([1, 1, 1]) # == doesn't work across devices
5164
@test y[3,:] isa CuArray
65+
@test y[3,:] == cu([0, 0, 0]) # == doesn't work across devices
5266
@test onecold(y, l) == ['a', 'a', 'a']
5367
end
5468

test/linalg.jl

+30
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,17 @@
1212
@test transpose(X) * b2 == transpose(X) * Array(b2)
1313
@test transpose(v) * b2 == transpose(v) * Array(b2)
1414
@test_throws DimensionMismatch A*b2
15+
16+
# in-place with mul!
17+
c1 = fill(NaN, 3)
18+
@test mul!(c1, A, b1) == A * Array(b1)
19+
@test c1 == A * Array(b1)
20+
@test mul!(c1, transpose(A), b1) == transpose(A) * Array(b1)
21+
@test mul!(zeros(3,1), A, b1) == reshape(A * b1, 3,1)
22+
@test mul!([NaN], transpose(v), b2) == mul!([NaN], transpose(v), Array(b2))
23+
24+
@test_throws DimensionMismatch mul!(zeros(5), A, b1)
25+
@test_throws DimensionMismatch mul!(c1, X, b1)
1526
end
1627

1728
@testset "AbstractMatrix-OneHotMatrix multiplication" begin
@@ -22,8 +33,10 @@ end
2233
b2 = OneHotMatrix([2, 4, 1, 3], 5)
2334
b3 = OneHotMatrix([1, 1, 2], 4)
2435
b4 = reshape(OneHotMatrix([1 2 3; 2 2 1], 3), 3, :)
36+
@test OneHotArrays._isonehot(b4)
2537
b5 = reshape(b4, 6, :)
2638
b6 = reshape(OneHotMatrix([1 2 2; 2 2 1], 2), 3, :)
39+
@test !OneHotArrays._isonehot(b6)
2740
b7 = reshape(OneHotMatrix([1 2 3; 1 2 3], 3), 6, :)
2841

2942
@test A * b1 == A[:,[1, 1, 2, 2]]
@@ -41,4 +54,21 @@ end
4154
@test_throws DimensionMismatch A*b2'
4255
@test_throws DimensionMismatch A*b6'
4356
@test_throws DimensionMismatch A*b7
57+
58+
# in-place with mul!
59+
c1 = fill(NaN, 3, 4)
60+
@test mul!(c1, A, b1) == A * b1
61+
@test c1 == A * b1
62+
63+
c4 = fill(NaN, 3, 6)
64+
@test mul!(c4, A, b4) == A * b4 # b4 is reshaped but still one-hot
65+
@test mul!(c4, A', b4) == A' * b4
66+
c6 = fill(NaN, 3, 4)
67+
@test mul!(c6, A, b6) == A * b6 # b4 is reshaped and not one-hot
68+
@test mul!(c6, A', b6) == A' * b6
69+
70+
@test_throws DimensionMismatch mul!(c1, A, b2)
71+
@test_throws DimensionMismatch mul!(c1, A, b4)
72+
@test_throws DimensionMismatch mul!(c4, A, b1)
73+
@test_throws DimensionMismatch mul!(zeros(10, 3), A, b1)
4474
end

test/runtests.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using OneHotArrays
2-
using Test
2+
using Test, LinearAlgebra
33
using Compat: stack
44

55
@testset "OneHotArray" begin

0 commit comments

Comments
 (0)