2
2
3
3
# GPUArrays' mapreduce methods build on `Base.mapreducedim!`, but with an additional
4
4
# argument `init` value to avoid eager initialization of `R` (if set to something).
5
- mapreducedim! (f, op, R:: AbstractGPUArray , A :: AbstractArray , init= nothing ) = error (" Not implemented" ) # COV_EXCL_LINE
5
+ mapreducedim! (f, op, R:: AbstractGPUArray , As :: AbstractArray... ; init= nothing ) = error (" Not implemented" ) # COV_EXCL_LINE
6
6
Base. mapreducedim! (f, op, R:: AbstractGPUArray , A:: AbstractArray ) = mapreducedim! (f, op, R, A)
7
7
8
8
neutral_element (op, T) =
@@ -18,11 +18,11 @@ neutral_element(::typeof(Base.mul_prod), T) = one(T)
18
18
neutral_element (:: typeof (Base. min), T) = typemax (T)
19
19
neutral_element (:: typeof (Base. max), T) = typemin (T)
20
20
21
- function Base. mapreduce (f, op, A :: AbstractGPUArray ; dims= :, init= nothing )
21
+ function Base. mapreduce (f, op, As :: AbstractGPUArray... ; dims= :, init= nothing )
22
22
# figure out the destination container type by looking at the initializer element,
23
23
# or by relying on inference to reason through the map and reduce functions.
24
24
if init === nothing
25
- ET = Base. promote_op (f, eltype (A) )
25
+ ET = Base. promote_op (f, map (eltype, As) ... )
26
26
ET = Base. promote_op (op, ET, ET)
27
27
(ET === Union{} || ET === Any) &&
28
28
error (" mapreduce cannot figure the output element type, please pass an explicit init value" )
@@ -32,10 +32,14 @@ function Base.mapreduce(f, op, A::AbstractGPUArray; dims=:, init=nothing)
32
32
ET = typeof (init)
33
33
end
34
34
35
+ # TODO : Broadcast-semantics after JuliaLang-julia#31020
36
+ A = first (As)
37
+ all (B -> size (A) == size (B), As) || throw (DimensionMismatch (" dimensions of containers must be identical" ))
38
+
35
39
sz = size (A)
36
40
red = ntuple (i-> (dims== Colon () || i in dims) ? 1 : sz[i], ndims (A))
37
41
R = similar (A, ET, red)
38
- mapreducedim! (f, op, R, A, init)
42
+ mapreducedim! (f, op, R, As ... ; init = init)
39
43
40
44
if dims== Colon ()
41
45
@allowscalar R[]
0 commit comments