@@ -420,11 +420,13 @@ end
420
420
# #### `mapfoldl(f, g, ::Tuple)`
421
421
# ####
422
422
423
+ using Base: mapfoldl_impl
424
+
423
425
# For tuples there should be no harm in handling `map` first.
424
426
# This will also catch `mapreduce`.
425
427
426
428
function rrule (
427
- cfg:: RuleConfig{>:HasReverseMode} , :: typeof (Base . mapfoldl_impl), f:: F , op:: G , init, x:: Tuple ;
429
+ cfg:: RuleConfig{>:HasReverseMode} , :: typeof (mapfoldl_impl), f:: F , op:: G , init, x:: Tuple ;
428
430
) where {F,G}
429
431
y, backmap = rrule (cfg, map, f, x)
430
432
z, backred = rrule (cfg, Base. mapfoldl_impl, identity, op, init, y)
@@ -436,6 +438,11 @@ function rrule(
436
438
return z, mapfoldl_pullback_tuple
437
439
end
438
440
441
+ function rrule (:: RuleConfig{>:HasReverseMode} , :: typeof (mapfoldl_impl), f, op, init, x:: Tuple{} )
442
+ foldl_pullback_empty (dy) = (NoTangent (), NoTangent (), NoTangent (), dy, NoTangent ())
443
+ return init, foldl_pullback_empty
444
+ end
445
+
439
446
# ####
440
447
# #### `foldl(f, ::Tuple)`
441
448
# ####
@@ -495,6 +502,12 @@ function rrule(
495
502
return y, foldl_pullback_tuple_init
496
503
end
497
504
505
+ # Base.tail doesn't work on (), trivial case:
506
+ function rrule (:: RuleConfig{>:HasReverseMode} , :: typeof (mapfoldl_impl), :: typeof (identity), op, init, x:: Tuple{} )
507
+ foldl_pullback_empty (dy) = (NoTangent (), NoTangent (), NoTangent (), dy, NoTangent ())
508
+ return init, foldl_pullback_empty
509
+ end
510
+
498
511
# ####
499
512
# #### `foldl(f, ::Array)`
500
513
# ####
0 commit comments