Skip to content

WIP: Implement broadcasting with AxisArrays on Julia 0.7 #131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/AxisArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,8 @@ include("indexing.jl")
include("sortedvector.jl")
include("categoricalvector.jl")
include("combine.jl")
@static if VERSION >= v"0.7.0-DEV.2638"
include("broadcast.jl")
end

end
61 changes: 61 additions & 0 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
Base.BroadcastStyle(::Type{<:AxisArray}) = Broadcast.ArrayStyle{AxisArray}()
Base.BroadcastStyle(::Type{<:Adjoint{T, <:AxisArray{T}}}) where T =
Broadcast.ArrayStyle{AxisArray}()

# Hijack broadcasting after determining style
function Base.broadcast(f, ::Broadcast.ArrayStyle{AxisArray}, ::Nothing, ::Nothing, As...)
# We need to make sure we can combine indices of only the AxisArrays before attempting
# broadcasting. The total broadcasting operation may include other AbstractArrays.
# We demand that for a given dimension, the axes values and names must match
# as implemented, this demands exact matching of axes (even floating point nums).
axesAs = Broadcast.combine_indices(axarrs(As)...)

# Obtain the underlying data and find the result indices if we were to
# broadcast all arrays without axis info.
Bs = data(As)

# Broadcast using the underlying data
broadcasted = broadcast(f, Bs...)

defaxesBs = default_axes(broadcasted)
axesBs = broadcax(axesAs, defaxesBs)
return AxisArray(broadcasted, axesBs)
end

broadcax(axes::Tuple, defaxes::Tuple) =
(broadcax1(axes[1], defaxes[1]), broadcax(tail(axes), tail(defaxes))...)
broadcax(axes::Tuple{}, defaxes::Tuple) = ()
broadcax1(::Tuple{}, x) = ()
function broadcax1(axA::Axis, axB::Axis)
axAname, axAvalues = axisname(axA), axisvalues(axA)[1]
axAname != axisname(axB) && return axA
if typeof(axAvalues) <: Base.OneTo
# We believe this was a default axis, not just an axis that happened to
# have the default name
return typeof(axA)(Base.OneTo(length(axB)))
else
error("axis values did not match.")
end
end

# Compares the value indices and axis names (note: AxisArrays.axes, not Base.axes)
Broadcast.broadcast_indices(::Broadcast.ArrayStyle{AxisArray}, A) = axes(A)
Broadcast.broadcast_indices(::Broadcast.ArrayStyle{AxisArray}, A::Adjoint{T,S}) where
{T, S<:AxisArray{T,1}} = (Axis{:row}(Base.OneTo(1)), axes(A.parent)[1])
Broadcast.broadcast_indices(::Broadcast.ArrayStyle{AxisArray}, A::Adjoint{T,S}) where
{T, S<:AxisArray{T,2}} = tupswap(axes(A.parent))

# Helper functions
# Given a tuple `A`, return a tuple containing only the AxisArrays (or their adjoints) in `A`
axarrs(A::Tuple{AxisArray, Vararg}) = (A[1], axarrs(Base.tail(A))...)
axarrs(A::Tuple{Adjoint{T, <:AxisArray} where T, Vararg}) = (A[1], axarrs(Base.tail(A))...)
axarrs(A::Tuple{Any, Vararg}) = axarrs(Base.tail(A))
axarrs(A::Tuple{}) = ()

data(A::Tuple{AxisArray,Vararg}) = (A[1].data, data(Base.tail(A))...)
data(A::Tuple{Adjoint{T, <:AxisArray} where T, Vararg}) =
(adjoint(A[1].parent.data), data(Base.tail(A))...)
data(A::Tuple{Any,Vararg}) = (A[1], data(Base.tail(A))...)
data(A::Tuple{}) = ()

tupswap(A::Tuple{Any,Any}) = (A[2],A[1])
144 changes: 144 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
A0 = [1,2,3]
A = AxisArray(A0, Axis{:abc}([1.0, 2.0, 3.0]))
A1 = AxisArray(A0, Axis{:def}([1.0, 2.0, 3.0]))
A2 = AxisArray(A0, Axis{:abc}([1.0, 2.0, 3.0+eps(3.0)]))

B0 = [1 2 3]
B = AxisArray(B0, Axis{:row}(Base.OneTo(1)), Axis{:def}([1.3, 2.4, 36]))
B1 = AxisArray(B0, Axis{:row}(Base.OneTo(1)), Axis{:abc}([1.0, 2.0, 3.0]))
B2 = AxisArray(B0, Axis{:abc}(Base.OneTo(1)), Axis{:def}([1.3, 2.4, 36]))

C0 = reshape([10])
C = AxisArray(C0)

D0 = ones(Complex, 3, 3)
D = AxisArray(D0, Axis{:abc}([1.0, 2.0, 3.0]), Axis{:def}([1.3, 2.4, 36]))
D1 = AxisArray(D0, Axis{:abc}([1.0, 2.0, 3.0+eps(3.0)]), Axis{:def}([1.3, 2.4, 36]))
D2 = AxisArray(D0, Axis{:row}(Base.OneTo(3)), Axis{:def}([1.3, 2.4, 36]))
D3 = AxisArray(D0, Axis{:abc}([1.0, 2.0, 3.1]), Axis{:def}([1.3, 2.4, 36]))

# AxisArray 0-d + number
@test (C .+ 1) isa AxisArray
@test @inferred(C .+ 1).data == reshape([11])
@test AxisArrays.axes(C .+ 1) == ()
@test (1 .+ C) isa AxisArray
@test @inferred(1 .+ C).data == reshape([11])
@test AxisArrays.axes(1 .+ C) == ()

# AxisArray vector + number
@test (A .+ 1) isa AxisArray
@test @inferred(A .+ 1).data == [2,3,4]
@test AxisArrays.axes(A .+ 1)[1] == Axis{:abc}([1.0, 2.0, 3.0])
@test (1 .+ A) isa AxisArray
@test @inferred(1 .+ A).data == [2,3,4]
@test AxisArrays.axes(1 .+ A)[1] == Axis{:abc}([1.0, 2.0, 3.0])

# AxisArray row-vector + number...
# AxisArray matrix + number...
# AxisArray higher-d + number...

# AxisArray 0-d + AxisArray 0-d
@test (C .+ C) isa AxisArray
@test @inferred(C .+ C).data == reshape([20])
@test AxisArrays.axes(C .+ C) == ()

# AxisArray 0-d + non-AxisArray 0-d
@test (C0 .+ C) isa AxisArray
@test @inferred(C0 .+ C).data == reshape([20])
@test AxisArrays.axes(C0 .+ C) == ()
@test (C .+ C0) isa AxisArray
@test @inferred(C .+ C0).data == reshape([20])
@test AxisArrays.axes(C .+ C0) == ()

# AxisArray vector + AxisArray 0-d
@test @inferred(A .+ C).data == [11,12,13]
@test AxisArrays.axes(A .+ C)[1] == Axis{:abc}([1.0, 2.0, 3.0])
@test length(AxisArrays.axes(A .+ C)) == 1
@test @inferred(C .+ A).data == [11,12,13]
@test AxisArrays.axes(C .+ A)[1] == Axis{:abc}([1.0, 2.0, 3.0])
@test length(AxisArrays.axes(C .+ A)) == 1

# AxisArray vector + non-AxisArray 0-d
@test @inferred(A .+ C0).data == [11,12,13]
@test AxisArrays.axes(A .+ C0)[1] == Axis{:abc}([1.0, 2.0, 3.0])
@test length(AxisArrays.axes(A .+ C0)) == 1
@test @inferred(C0 .+ A).data == [11,12,13]
@test AxisArrays.axes(C0 .+ A)[1] == Axis{:abc}([1.0, 2.0, 3.0])
@test length(AxisArrays.axes(C0 .+ A)) == 1

# AxisArray vector + AxisArray vector
@test @inferred(A .+ A).data == [2,4,6]
@test AxisArrays.axes(A .+ A)[1] == Axis{:abc}([1.0, 2.0, 3.0])
@test length(AxisArrays.axes(A .+ A)) == 1
@test_throws DimensionMismatch (A.+A1) # axis name mismatch
@test_throws DimensionMismatch (A1.+A)
@test_throws DimensionMismatch (A.+A2) # axis value mismatch (floating-points count)
@test_throws DimensionMismatch (A2.+A)

# AxisArray vector + non-AxisArray vector
@test @inferred(A .+ A0).data == [2,4,6]
@test AxisArrays.axes(A .+ A0)[1] == Axis{:abc}([1.0, 2.0, 3.0])
@test length(AxisArrays.axes(A .+ A0)) == 1
@test @inferred(A0 .+ A).data == [2,4,6]
@test AxisArrays.axes(A0 .+ A)[1] == Axis{:abc}([1.0, 2.0, 3.0])
@test length(AxisArrays.axes(A0 .+ A)) == 1

# AxisArray vector + 1xN AxisArray matrix
@test_broken @inferred(A .+ B).data == [2 3 4; 3 4 5; 4 5 6] # output good but axes aren't yet inferred...
@test length(AxisArrays.axes(A .+ B)) == 2
@test AxisArrays.axes(A .+ B)[1] == Axis{:abc}([1.0, 2.0, 3.0])
@test AxisArrays.axes(A .+ B)[2] == Axis{:def}([1.3, 2.4, 36])

@test_broken @inferred(B .+ A).data == [2 3 4; 3 4 5; 4 5 6] # output good but axes aren't yet inferred...
@test length(AxisArrays.axes(B .+ A)) == 2
@test AxisArrays.axes(B .+ A)[1] == Axis{:abc}([1.0, 2.0, 3.0])
@test AxisArrays.axes(B .+ A)[2] == Axis{:def}([1.3, 2.4, 36])

@test_throws ArgumentError (A.+B1) # axis names don't match
@test_throws ArgumentError (B1.+A)
@test_broken @test_throws DimensionMismatch (A.+B2)
@test_broken @test_throws DimensionMismatch (B2.+A)

# AxisArray vector + 1xN non-AxisArray matrix
@test @inferred(A.+B0).data == [2 3 4; 3 4 5; 4 5 6]
@test length(AxisArrays.axes(A .+ B0)) == 2
@test AxisArrays.axes(A .+ B0)[1] == Axis{:abc}([1.0, 2.0, 3.0])
@test AxisArrays.axes(A .+ B0)[2] == Axis{:col}(Base.OneTo(3))

# AxisArray vector + NxN AxisArray matrix
@test_broken @inferred(A .+ D).data ==
[2+0im 2+0im 2+0im;
3+0im 3+0im 3+0im;
4+0im 4+0im 4+0im] # output good but inference dies
@test length(AxisArrays.axes(A .+ D)) == 2
@test AxisArrays.axes(A .+ D)[1] == Axis{:abc}([1.0, 2.0, 3.0])
@test AxisArrays.axes(A .+ D)[2] == Axis{:def}([1.3, 2.4, 36])
@test_broken @inferred(D .+ A).data ==
[2+0im 2+0im 2+0im;
3+0im 3+0im 3+0im;
4+0im 4+0im 4+0im] # output good but inference dies
@test length(AxisArrays.axes(D .+ A)) == 2
@test AxisArrays.axes(D .+ A)[1] == Axis{:abc}([1.0, 2.0, 3.0])
@test AxisArrays.axes(D .+ A)[2] == Axis{:def}([1.3, 2.4, 36])
@test_throws DimensionMismatch (A.+D1)
@test_throws DimensionMismatch (D1.+A)
@test_throws DimensionMismatch (A.+D2)
@test_throws DimensionMismatch (D2.+A)
@test_throws DimensionMismatch (A.+D3)
@test_throws DimensionMismatch (D3.+A)

# AxisArray vector + NxN non-AxisArray matrix
@test_broken @inferred(A .+ D0).data ==
[2+0im 2+0im 2+0im;
3+0im 3+0im 3+0im;
4+0im 4+0im 4+0im] # output good but inference dies
@test length(AxisArrays.axes(A .+ D0)) == 2
@test AxisArrays.axes(A .+ D0)[1] == Axis{:abc}([1.0, 2.0, 3.0])
@test AxisArrays.axes(A .+ D0)[2] == Axis{:col}(Base.OneTo(3))
@test_broken @inferred(D0 .+ A).data ==
[2+0im 2+0im 2+0im;
3+0im 3+0im 3+0im;
4+0im 4+0im 4+0im] # output good but inference dies
@test length(AxisArrays.axes(D0 .+ A)) == 2
@test AxisArrays.axes(D0 .+ A)[1] == Axis{:abc}([1.0, 2.0, 3.0])
@test AxisArrays.axes(D0 .+ A)[2] == Axis{:col}(Base.OneTo(3))
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ import IterTools
include("combine.jl")
end

@static if VERSION >= v"0.7.0-DEV.2638"
@testset "Broadcast" begin
include("broadcast.jl")
end
end

@testset "README" begin
include("readme.jl")
end
Expand Down