Skip to content

Commit 1b5d7a1

Browse files
committed
fix #672
1 parent e8bfd72 commit 1b5d7a1

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

src/rulesets/Base/mapreduce.jl

+14-1
Original file line numberDiff line numberDiff line change
@@ -420,11 +420,13 @@ end
420420
##### `mapfoldl(f, g, ::Tuple)`
421421
#####
422422

423+
using Base: mapfoldl_impl
424+
423425
# For tuples there should be no harm in handling `map` first.
424426
# This will also catch `mapreduce`.
425427

426428
function rrule(
427-
cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), f::F, op::G, init, x::Tuple;
429+
cfg::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), f::F, op::G, init, x::Tuple;
428430
) where {F,G}
429431
y, backmap = rrule(cfg, map, f, x)
430432
z, backred = rrule(cfg, Base.mapfoldl_impl, identity, op, init, y)
@@ -436,6 +438,11 @@ function rrule(
436438
return z, mapfoldl_pullback_tuple
437439
end
438440

441+
function rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), f, op, init, x::Tuple{})
442+
foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent())
443+
return init, foldl_pullback_empty
444+
end
445+
439446
#####
440447
##### `foldl(f, ::Tuple)`
441448
#####
@@ -495,6 +502,12 @@ function rrule(
495502
return y, foldl_pullback_tuple_init
496503
end
497504

505+
# Base.tail doesn't work on (), trivial case:
506+
function rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), ::typeof(identity), op, init, x::Tuple{})
507+
foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent())
508+
return init, foldl_pullback_empty
509+
end
510+
498511
#####
499512
##### `foldl(f, ::Array)`
500513
#####

test/rulesets/Base/mapreduce.jl

+4
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,10 @@ const _INIT = Base._InitialValue()
303303
# Finite differencing
304304
test_rrule(mapfoldl_impl, identity, /, _INIT, Tuple(1 .+ rand(5)))
305305
test_rrule(mapfoldl_impl, identity, *, 1+rand(), Tuple(rand(ComplexF64, 5)))
306+
307+
# Trivial case
308+
test_rrule(mapfoldl_impl, identity, /, 2pi, ())
309+
test_rrule(mapfoldl_impl, sqrt, /, 2pi, ())
306310
end
307311
@testset "mapfoldl(f, g, ::Tuple)" begin
308312
test_rrule(mapfoldl_impl, cbrt, /, _INIT, Tuple(1 .+ rand(5)), check_inferred=false)

0 commit comments

Comments
 (0)