Skip to content

Commit a059e19

Browse files
authored
fix: handle constant ConstantOrCache with Enzyme and SCT (#753)
* fix: handle constant `ConstantOrCache` with Enzyme * Fixes
1 parent cfab84d commit a059e19

File tree

9 files changed

+73
-67
lines changed

9 files changed

+73
-67
lines changed

DifferentiationInterface/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.47"
4+
version = "0.6.48"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

+13-1
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,27 @@ force_annotation(f::F) where {F} = Const(f)
5454
end
5555

5656
@inline function _translate(
57-
backend::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Cache,DI.GeneralizedConstantOrCache}
57+
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache
5858
) where {B}
59+
# important to keep make_zero here for ConstantOrCache instead of similar
5960
if B == 1
6061
return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c)))
6162
else
6263
return BatchDuplicated(DI.unwrap(c), ntuple(_ -> make_zero(DI.unwrap(c)), Val(B)))
6364
end
6465
end
6566

67+
@inline function _translate(
68+
backend::AutoEnzyme, mode::Mode, valB::Val{B}, c::DI.ConstantOrCache
69+
) where {B}
70+
IA = guess_activity(typeof(DI.unwrap(c)), mode)
71+
if IA <: Const
72+
return _translate(backend, mode, valB, DI.Constant(DI.unwrap(c)))
73+
else
74+
return _translate(backend, mode, valB, DI.Cache(DI.unwrap(c)))
75+
end
76+
end
77+
6678
@inline function _translate(
6779
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext
6880
) where {B}

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl

+3-5
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B}
8383
end
8484

8585
function _translate(
86-
::Type{D}, c::Union{DI.GeneralizedConstant,DI.GeneralizedConstantOrCache}
86+
::Type{D}, c::Union{DI.GeneralizedConstant,DI.ConstantOrCache}
8787
) where {D<:Dual}
8888
return DI.unwrap(c)
8989
end
@@ -100,7 +100,7 @@ function translate(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C}
100100
end
101101

102102
function _translate_toprep(
103-
::Type{D}, c::Union{DI.GeneralizedConstant,DI.GeneralizedConstantOrCache}
103+
::Type{D}, c::Union{DI.GeneralizedConstant,DI.ConstantOrCache}
104104
) where {D<:Dual}
105105
return nothing
106106
end
@@ -116,9 +116,7 @@ function translate_toprep(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:D
116116
return new_contexts
117117
end
118118

119-
function _translate_prepared(
120-
c::Union{DI.GeneralizedConstant,DI.GeneralizedConstantOrCache}, _pc
121-
)
119+
function _translate_prepared(c::Union{DI.GeneralizedConstant,DI.ConstantOrCache}, _pc)
122120
return DI.unwrap(c)
123121
end
124122
_translate_prepared(_c::DI.Cache, pc) = pc

DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ import DifferentiationInterface as DI
55
using SparseConnectivityTracer:
66
TracerSparsityDetector, TracerLocalSparsityDetector, jacobian_buffer, hessian_buffer
77

8-
@inline _translate(::Type, c::DI.Constant) = DI.unwrap(c)
8+
@inline function _translate(::Type, c::Union{DI.GeneralizedConstant,DI.ConstantOrCache})
9+
return DI.unwrap(c)
10+
end
911
@inline function _translate(::Type{T}, c::DI.Cache) where {T}
1012
return DI.recursive_similar(DI.unwrap(c), T)
1113
end

DifferentiationInterface/src/second_order/hvp.jl

+31-31
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ function _prepare_hvp_aux(
117117
rewrap = Rewrap(contexts...)
118118
# Outer pushforward
119119
new_contexts = (
120-
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
120+
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
121121
)
122122
outer_pushforward_prep = prepare_pushforward_nokwarg(
123123
strict, shuffled_gradient, outer(backend), x, tx, new_contexts...
@@ -161,15 +161,15 @@ function _prepare_hvp_aux(
161161
# Outer pushforward
162162
new_contexts = (
163163
FunctionContext(f),
164-
PrepContext(inner_gradient_prep),
165-
BackendContext(inner(backend)),
164+
ConstantOrCache(inner_gradient_prep),
165+
Constant(inner(backend)),
166166
Constant(rewrap),
167167
contexts...,
168168
)
169169
new_contexts_in = (
170170
FunctionContext(f),
171-
PrepContext(inner_gradient_in_prep),
172-
BackendContext(inner(backend)),
171+
ConstantOrCache(inner_gradient_in_prep),
172+
Constant(inner(backend)),
173173
Constant(rewrap),
174174
contexts...,
175175
)
@@ -228,15 +228,15 @@ function _prepare_hvp_aux(
228228
# Outer pushforward
229229
new_contexts = (
230230
FunctionContext(f),
231-
PrepContext(inner_gradient_prep),
232-
BackendContext(inner(backend)),
231+
ConstantOrCache(inner_gradient_prep),
232+
Constant(inner(backend)),
233233
Constant(rewrap),
234234
contexts...,
235235
)
236236
new_contexts_in = (
237237
FunctionContext(f),
238-
PrepContext(inner_gradient_in_prep),
239-
BackendContext(inner(backend)),
238+
ConstantOrCache(inner_gradient_in_prep),
239+
Constant(inner(backend)),
240240
Constant(rewrap),
241241
contexts...,
242242
)
@@ -279,8 +279,8 @@ function hvp(
279279
rewrap = Rewrap(contexts...)
280280
new_contexts = (
281281
FunctionContext(f),
282-
map(PrepContext, maybe_inner_gradient_prep)...,
283-
BackendContext(inner(backend)),
282+
map(ConstantOrCache, maybe_inner_gradient_prep)...,
283+
Constant(inner(backend)),
284284
Constant(rewrap),
285285
contexts...,
286286
)
@@ -318,8 +318,8 @@ function _hvp_aux!(
318318
rewrap = Rewrap(contexts...)
319319
new_contexts = (
320320
FunctionContext(f),
321-
map(PrepContext, maybe_inner_gradient_in_prep)...,
322-
BackendContext(inner(backend)),
321+
map(ConstantOrCache, maybe_inner_gradient_in_prep)...,
322+
Constant(inner(backend)),
323323
Constant(rewrap),
324324
contexts...,
325325
)
@@ -349,8 +349,8 @@ function _hvp_aux!(
349349
rewrap = Rewrap(contexts...)
350350
new_contexts = (
351351
FunctionContext(f),
352-
map(PrepContext, maybe_inner_gradient_prep)...,
353-
BackendContext(inner(backend)),
352+
map(ConstantOrCache, maybe_inner_gradient_prep)...,
353+
Constant(inner(backend)),
354354
Constant(rewrap),
355355
contexts...,
356356
)
@@ -378,8 +378,8 @@ function gradient_and_hvp(
378378
rewrap = Rewrap(contexts...)
379379
new_contexts = (
380380
FunctionContext(f),
381-
map(PrepContext, maybe_inner_gradient_prep)...,
382-
BackendContext(inner(backend)),
381+
map(ConstantOrCache, maybe_inner_gradient_prep)...,
382+
Constant(inner(backend)),
383383
Constant(rewrap),
384384
contexts...,
385385
)
@@ -419,8 +419,8 @@ function _gradient_and_hvp_aux!(
419419
rewrap = Rewrap(contexts...)
420420
new_contexts = (
421421
FunctionContext(f),
422-
map(PrepContext, maybe_inner_gradient_in_prep)...,
423-
BackendContext(inner(backend)),
422+
map(ConstantOrCache, maybe_inner_gradient_in_prep)...,
423+
Constant(inner(backend)),
424424
Constant(rewrap),
425425
contexts...,
426426
)
@@ -452,8 +452,8 @@ function _gradient_and_hvp_aux!(
452452
rewrap = Rewrap(contexts...)
453453
new_contexts = (
454454
FunctionContext(f),
455-
map(PrepContext, maybe_inner_gradient_prep)...,
456-
BackendContext(inner(backend)),
455+
map(ConstantOrCache, maybe_inner_gradient_prep)...,
456+
Constant(inner(backend)),
457457
Constant(rewrap),
458458
contexts...,
459459
)
@@ -492,7 +492,7 @@ function _prepare_hvp_aux(
492492
rewrap = Rewrap(contexts...)
493493
new_contexts = (
494494
FunctionContext(f),
495-
BackendContext(inner(backend)),
495+
Constant(inner(backend)),
496496
Constant(first(tx)),
497497
Constant(rewrap),
498498
contexts...,
@@ -522,7 +522,7 @@ function hvp(
522522
outer(backend),
523523
x,
524524
FunctionContext(f),
525-
BackendContext(inner(backend)),
525+
Constant(inner(backend)),
526526
Constant(dx),
527527
Constant(rewrap),
528528
contexts...,
@@ -551,7 +551,7 @@ function hvp!(
551551
outer(backend),
552552
x,
553553
FunctionContext(f),
554-
BackendContext(inner(backend)),
554+
Constant(inner(backend)),
555555
Constant(tx[b]),
556556
Constant(rewrap),
557557
contexts...,
@@ -613,7 +613,7 @@ function _prepare_hvp_aux(
613613
_sig = signature(f, backend, x, tx, contexts...; strict)
614614
rewrap = Rewrap(contexts...)
615615
new_contexts = (
616-
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
616+
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
617617
)
618618
grad_buffer = similar(x)
619619
outer_pullback_prep = prepare_pullback_nokwarg(
@@ -649,7 +649,7 @@ function hvp(
649649
(; outer_pullback_prep) = prep
650650
rewrap = Rewrap(contexts...)
651651
new_contexts = (
652-
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
652+
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
653653
)
654654
return pullback(
655655
shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts...
@@ -684,7 +684,7 @@ function _hvp_aux!(
684684
(; grad_buffer, outer_pullback_in_prep) = prep
685685
rewrap = Rewrap(contexts...)
686686
new_contexts = (
687-
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
687+
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
688688
)
689689
return pullback!(
690690
shuffled_gradient!,
@@ -711,7 +711,7 @@ function _hvp_aux!(
711711
(; outer_pullback_prep) = prep
712712
rewrap = Rewrap(contexts...)
713713
new_contexts = (
714-
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
714+
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
715715
)
716716
return pullback!(
717717
shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts...
@@ -730,7 +730,7 @@ function gradient_and_hvp(
730730
(; outer_pullback_prep) = prep
731731
rewrap = Rewrap(contexts...)
732732
new_contexts = (
733-
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
733+
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
734734
)
735735
return value_and_pullback(
736736
shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts...
@@ -767,7 +767,7 @@ function _gradient_and_hvp_aux!(
767767
(; outer_pullback_in_prep) = prep
768768
rewrap = Rewrap(contexts...)
769769
new_contexts = (
770-
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
770+
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
771771
)
772772
new_grad, _ = value_and_pullback!(
773773
shuffled_gradient!,
@@ -796,7 +796,7 @@ function _gradient_and_hvp_aux!(
796796
(; outer_pullback_prep) = prep
797797
rewrap = Rewrap(contexts...)
798798
new_contexts = (
799-
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
799+
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
800800
)
801801
new_grad, _ = value_and_pullback!(
802802
shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts...

DifferentiationInterface/src/second_order/second_derivative.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ function prepare_second_derivative_nokwarg(
6767
_sig = signature(f, backend, x, contexts...; strict)
6868
rewrap = Rewrap(contexts...)
6969
new_contexts = (
70-
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
70+
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
7171
)
7272
outer_derivative_prep = prepare_derivative_nokwarg(
7373
strict, shuffled_derivative, outer(backend), x, new_contexts...
@@ -88,7 +88,7 @@ function second_derivative(
8888
(; outer_derivative_prep) = prep
8989
rewrap = Rewrap(contexts...)
9090
new_contexts = (
91-
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
91+
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
9292
)
9393
return derivative(
9494
shuffled_derivative, outer_derivative_prep, outer(backend), x, new_contexts...
@@ -106,7 +106,7 @@ function value_derivative_and_second_derivative(
106106
(; outer_derivative_prep) = prep
107107
rewrap = Rewrap(contexts...)
108108
new_contexts = (
109-
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
109+
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
110110
)
111111
y = f(x, map(unwrap, contexts)...)
112112
der, der2 = value_and_derivative(
@@ -127,7 +127,7 @@ function second_derivative!(
127127
(; outer_derivative_prep) = prep
128128
rewrap = Rewrap(contexts...)
129129
new_contexts = (
130-
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
130+
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
131131
)
132132
return derivative!(
133133
shuffled_derivative, der2, outer_derivative_prep, outer(backend), x, new_contexts...
@@ -147,7 +147,7 @@ function value_derivative_and_second_derivative!(
147147
(; outer_derivative_prep) = prep
148148
rewrap = Rewrap(contexts...)
149149
new_contexts = (
150-
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
150+
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
151151
)
152152
y = f(x, map(unwrap, contexts)...)
153153
new_der, _ = value_and_derivative!(

DifferentiationInterface/src/utils/context.jl

+1-22
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ Abstract supertype for additional context arguments, which can be passed to diff
1212
abstract type Context end
1313

1414
abstract type GeneralizedConstant <: Context end
15-
abstract type GeneralizedConstantOrCache <: Context end
1615

1716
unwrap(c::Context) = c.data
1817
Base.:(==)(c1::Context, c2::Context) = unwrap(c1) == unwrap(c2)
@@ -102,7 +101,7 @@ Concrete type of [`Context`](@ref) argument which can contain a mixture of const
102101
103102
Unlike for [`Cache`](@ref), it is up to the user to ensure that the internal storage can adapt to the required element types, for instance by using [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl) directly.
104103
"""
105-
struct ConstantOrCache{T} <: GeneralizedConstantOrCache
104+
struct ConstantOrCache{T} <: Context
106105
data::T
107106
end
108107

@@ -123,26 +122,6 @@ struct FunctionContext{T} <: GeneralizedConstant
123122
data::T
124123
end
125124

126-
"""
127-
BackendContext
128-
129-
Private type of [`Context`](@ref) argument used for passing backends inside second-order differentiation.
130-
"""
131-
struct BackendContext{T} <: GeneralizedConstant
132-
data::T
133-
end
134-
135-
"""
136-
PrepContext
137-
138-
Private type of [`Context`](@ref) argument used for passing preparation results inside second-order differentiation.
139-
140-
Conceptually similar to [`ConstantOrCache`](@ref) because we assume that preparation was performed with the right types so we don't change anything.
141-
"""
142-
struct PrepContext{T} <: GeneralizedConstantOrCache
143-
data::T
144-
end
145-
146125
## Context manipulation
147126

148127
"""

DifferentiationInterface/test/Back/Enzyme/test.jl

+12
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using ADTypes: ADTypes
55
using DifferentiationInterface, DifferentiationInterfaceTest
66
import DifferentiationInterfaceTest as DIT
77
using Enzyme: Enzyme
8+
using LinearAlgebra
89
using StaticArrays
910
using Test
1011

@@ -136,3 +137,14 @@ end
136137
logging=LOGGING,
137138
)
138139
end
140+
141+
@testset "Coverage" begin
142+
# ConstantOrCache without cache
143+
f_nocontext(x, p) = x
144+
@test I == DifferentiationInterface.jacobian(
145+
f_nocontext, AutoEnzyme(; mode=Enzyme.Forward), rand(10), ConstantOrCache(nothing)
146+
)
147+
@test I == DifferentiationInterface.jacobian(
148+
f_nocontext, AutoEnzyme(; mode=Enzyme.Reverse), rand(10), ConstantOrCache(nothing)
149+
)
150+
end

0 commit comments

Comments
 (0)