Skip to content

Commit 88114d4

Browse files
authored
Merge pull request #258 from JuliaGPU/tb/mapreduce_vararg
Specialize varargs version of mapreduce.
2 parents bf2d515 + adb33d0 commit 88114d4

File tree

3 files changed

+15
-6
lines changed

3 files changed

+15
-6
lines changed

src/host/mapreduce.jl

+8-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# GPUArrays' mapreduce methods build on `Base.mapreducedim!`, but with an additional
44
# argument `init` value to avoid eager initialization of `R` (if set to something).
5-
mapreducedim!(f, op, R::AbstractGPUArray, A::AbstractArray, init=nothing) = error("Not implemented") # COV_EXCL_LINE
5+
mapreducedim!(f, op, R::AbstractGPUArray, As::AbstractArray...; init=nothing) = error("Not implemented") # COV_EXCL_LINE
66
Base.mapreducedim!(f, op, R::AbstractGPUArray, A::AbstractArray) = mapreducedim!(f, op, R, A)
77

88
neutral_element(op, T) =
@@ -18,11 +18,11 @@ neutral_element(::typeof(Base.mul_prod), T) = one(T)
1818
neutral_element(::typeof(Base.min), T) = typemax(T)
1919
neutral_element(::typeof(Base.max), T) = typemin(T)
2020

21-
function Base.mapreduce(f, op, A::AbstractGPUArray; dims=:, init=nothing)
21+
function Base.mapreduce(f, op, As::AbstractGPUArray...; dims=:, init=nothing)
2222
# figure out the destination container type by looking at the initializer element,
2323
# or by relying on inference to reason through the map and reduce functions.
2424
if init === nothing
25-
ET = Base.promote_op(f, eltype(A))
25+
ET = Base.promote_op(f, map(eltype, As)...)
2626
ET = Base.promote_op(op, ET, ET)
2727
(ET === Union{} || ET === Any) &&
2828
error("mapreduce cannot figure the output element type, please pass an explicit init value")
@@ -32,10 +32,14 @@ function Base.mapreduce(f, op, A::AbstractGPUArray; dims=:, init=nothing)
3232
ET = typeof(init)
3333
end
3434

35+
# TODO: Broadcast-semantics after JuliaLang-julia#31020
36+
A = first(As)
37+
all(B -> size(A) == size(B), As) || throw(DimensionMismatch("dimensions of containers must be identical"))
38+
3539
sz = size(A)
3640
red = ntuple(i->(dims==Colon() || i in dims) ? 1 : sz[i], ndims(A))
3741
R = similar(A, ET, red)
38-
mapreducedim!(f, op, R, A, init)
42+
mapreducedim!(f, op, R, As...; init=init)
3943

4044
if dims==Colon()
4145
@allowscalar R[]

src/reference.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -295,11 +295,11 @@ Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} =
295295
GPUArrays.unsafe_reinterpret(::Type{T}, A::JLArray, size::Tuple) where T =
296296
reshape(reinterpret(T, A.data), size)
297297

298-
function GPUArrays.mapreducedim!(f, op, R::JLArray, A::AbstractArray, init=nothing)
298+
function GPUArrays.mapreducedim!(f, op, R::JLArray, As::AbstractArray...; init=nothing)
299299
if init !== nothing
300300
fill!(R, init)
301301
end
302-
@allowscalar Base.mapreducedim!(f, op, R.data, A)
302+
@allowscalar Base.reducedim!(op, R.data, map(f, As...))
303303
end
304304

305305
end

test/testsuite/mapreduce.jl

+5
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,11 @@ function test_mapreduce(AT)
136136
AT, rand(range, N, N))
137137
@test compare(x->mapreduce(_addone, +, x; dims = 2, init = _zero),
138138
AT, rand(range, N, N))
139+
140+
@test compare(x->mapreduce(+, +, x; dims = 2),
141+
AT, rand(range, N, N), rand(range, N, N))
142+
@test compare(x->mapreduce(+, +, x; dims = 2, init = _zero),
143+
AT, rand(range, N, N). rand(range, N, N))
139144
end
140145
end
141146
@testset "sum maximum minimum prod" begin

0 commit comments

Comments
 (0)