Skip to content

Commit cf6ad3e

Browse files
authored
fix: improve performance of unbatched out-of-place Jacobian (#876)
1 parent 2bc92ce commit cf6ad3e

File tree

3 files changed

+86
-1
lines changed

3 files changed

+86
-1
lines changed

DifferentiationInterface/src/first_order/jacobian.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,35 @@ function _jacobian_aux(
372372
end
373373
end
374374

375+
function _jacobian_aux(
376+
f_or_f!y::FY,
377+
prep::PushforwardJacobianPrep{SIG, <:BatchSizeSettings{1, false, true}},
378+
backend::AbstractADType,
379+
x,
380+
contexts::Vararg{Context, C},
381+
) where {FY, SIG, C}
382+
(; batched_seeds, seed_example, pushforward_prep) = prep
383+
384+
pushforward_prep_same = prepare_pushforward_same_point(
385+
f_or_f!y..., pushforward_prep, backend, x, seed_example, contexts...
386+
)
387+
388+
jac = stack(eachindex(batched_seeds); dims = 2) do a
389+
dy = only(
390+
pushforward(
391+
f_or_f!y...,
392+
pushforward_prep_same,
393+
backend,
394+
x,
395+
batched_seeds[a],
396+
contexts...,
397+
)
398+
)
399+
return vec(dy)
400+
end
401+
return jac
402+
end
403+
375404
function _jacobian_aux(
376405
f_or_f!y::FY,
377406
prep::PushforwardJacobianPrep{SIG, <:BatchSizeSettings{B, false, aligned}},
@@ -428,6 +457,34 @@ function _jacobian_aux(
428457
end
429458
end
430459

460+
function _jacobian_aux(
461+
f_or_f!y::FY,
462+
prep::PullbackJacobianPrep{SIG, <:BatchSizeSettings{1, false, true}},
463+
backend::AbstractADType,
464+
x,
465+
contexts::Vararg{Context, C},
466+
) where {FY, SIG, C}
467+
(; batched_seeds, seed_example, pullback_prep) = prep
468+
469+
pullback_prep_same = prepare_pullback_same_point(
470+
f_or_f!y..., pullback_prep, backend, x, seed_example, contexts...
471+
)
472+
473+
jac = stack(eachindex(batched_seeds); dims = 1) do a
474+
dx = only(
475+
pullback(
476+
f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts...
477+
)
478+
)
479+
if eltype(x) <: Complex
480+
return map(conj, vec(dx))
481+
else
482+
return vec(dx)
483+
end
484+
end
485+
return jac
486+
end
487+
431488
function _jacobian_aux(
432489
f_or_f!y::FY,
433490
prep::PullbackJacobianPrep{SIG, <:BatchSizeSettings{B, false, aligned}},

DifferentiationInterface/src/second_order/hessian.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,27 @@ function hessian(
151151
return block
152152
end
153153

154+
function hessian(
155+
f::F,
156+
prep::HVPGradientHessianPrep{SIG, <:BatchSizeSettings{1, false, true}},
157+
backend::AbstractADType,
158+
x,
159+
contexts::Vararg{Context, C},
160+
) where {F, SIG, C}
161+
check_prep(f, prep, backend, x, contexts...)
162+
(; batched_seeds, seed_example, hvp_prep) = prep
163+
164+
hvp_prep_same = prepare_hvp_same_point(
165+
f, hvp_prep, backend, x, seed_example, contexts...
166+
)
167+
168+
hess = mapreduce(hcat, eachindex(batched_seeds)) do a
169+
dg = only(hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...))
170+
return vec(dg)
171+
end
172+
return hess
173+
end
174+
154175
function hessian(
155176
f::F,
156177
prep::HVPGradientHessianPrep{SIG, <:BatchSizeSettings{B, false, aligned}},

DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,14 @@ end
7474
logging = LOGGING,
7575
)
7676

77-
test_differentiation(backends, complex_scenarios(); logging = LOGGING)
77+
test_differentiation(
78+
vcat(
79+
backends[2:3],
80+
AutoReverseFromPrimitive(AutoSimpleFiniteDiff(; chunksize = 1))
81+
),
82+
complex_scenarios();
83+
logging = LOGGING
84+
)
7885
end
7986

8087
@testset "Sparse" begin

0 commit comments

Comments
 (0)