diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index f9082000e..eba5b8628 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -168,7 +168,9 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context,C}, ) where {F,C} - dx = map(CartesianIndices(x)) do j + ind = CartesianIndices(x) + T = typeof(similar(x, eltype(ind))) + dx = map(x, T(ind)) do xj, j t1 = pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...) dot(only(t1), dy) end @@ -254,7 +256,9 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context,C}, ) where {F,C} - dx = map(CartesianIndices(x)) do j # preserve shape + ind = CartesianIndices(x) + T = typeof(similar(x, eltype(ind))) + dx = map(x, T(ind)) do xj, j # preserve shape t1 = pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...) dot(only(t1), dy) end diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 69b253b0b..d8e4a547a 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -171,7 +171,9 @@ function _pushforward_via_pullback( dx, contexts::Vararg{Context,C}, ) where {F,C} - dy = map(CartesianIndices(y)) do i + ind = CartesianIndices(y) + T = typeof(similar(y, eltype(ind))) + dy = map(y, T(ind)) do yi, i t1 = pullback(f, pullback_prep, backend, x, (basis(y, i),), contexts...) dot(only(t1), dx) end @@ -243,7 +245,9 @@ function _pushforward_via_pullback( dx, contexts::Vararg{Context,C}, ) where {F,C} - dy = map(CartesianIndices(y)) do i # preserve shape + ind = CartesianIndices(y) + T = typeof(similar(y, eltype(ind))) + dy = map(y, T(ind)) do yi, i # preserve shape t1 = pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...) dot(only(t1), dx) end