Skip to content

Commit 202c86f

Browse files
committed
Use Zygote.jacobian etc.
1 parent afec712 commit 202c86f

File tree

3 files changed

+46
-7
lines changed

3 files changed

+46
-7
lines changed

ext/AbstractDifferentiationZygoteExt.jl

+18-2
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,27 @@ else
88
using ..Zygote: Zygote
99
end
1010

11-
AD.ZygoteBackend() = AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig())
12-
1311
# Context should not persist between different AD calls: fixes #69
1412
function AD.ruleconfig(::AD.ReverseRuleConfigBackend{<:Zygote.ZygoteRuleConfig})
1513
return Zygote.ZygoteRuleConfig()
1614
end
1715

16+
function AD.value_and_pullback_function(::AD.ZygoteBackend, f, args...)
17+
return Zygote.pullback(f, args...)
18+
end
19+
20+
AD.gradient(::AD.ZygoteBackend, f, args...) = Zygote.gradient(f, args...)
21+
function AD.value_and_gradient(::AD.ZygoteBackend, f, args...)
22+
res = Zygote.withgradient(f, args...)
23+
return res.val, res.grad
24+
end
25+
26+
AD.jacobian(::AD.ZygoteBackend, f, args...) = Zygote.jacobian(f, args...)
27+
function AD.value_and_jacobian(::AD.ZygoteBackend, f, args...)
28+
res = Zygote.withjacobian(f, args...)
29+
return res.val, res.grad
30+
end
31+
32+
AD.hessian(::AD.ZygoteBackend, f, arg) = Zygote.hessian(f, arg)
33+
1834
end # module

src/backends.jl

+4-3
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,14 @@ end
7171
ruleconfig(ba::ReverseRuleConfigBackend) = ba.ruleconfig
7272

7373
"""
74-
ZygoteBackend()
74+
ZygoteBackend
7575
7676
Create an AD backend that uses reverse mode with [Zygote.jl](https://github.com/FluxML/Zygote.jl).
7777
78-
It is a special case of [`ReverseRuleConfigBackend`](@ref).
78+
Alternatively, you can perform AD with Zygote using a special [`ReverseRuleConfigBackend`](@ref), namely `ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig())`.
79+
Note, however, that the behaviour of this backend is not equivalent to `ZygoteBackend()` since the former uses a generic implementation of jacobian etc. for ChainRules-compatible AD backends whereas `ZygoteBackend` uses implementations in Zygote.jl.
7980
8081
!!! note
8182
To be able to use this backend, you have to load Zygote.
8283
"""
83-
function ZygoteBackend end
84+
struct ZygoteBackend <: AbstractReverseMode end

test/ruleconfig.jl

+24-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ using Test
44
using Zygote
55

66
@testset "ReverseRuleConfigBackend(ZygoteRuleConfig())" begin
7-
backends = [@inferred(AD.ZygoteBackend())]
7+
backends = [
8+
@inferred(AD.ZygoteBackend()),
9+
@inferred(AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()))
10+
]
811
@testset for backend in backends
912
@testset "Derivative" begin
1013
test_derivatives(backend)
@@ -34,7 +37,7 @@ using Zygote
3437

3538
# issue #69
3639
@testset "Zygote context" begin
37-
ad = AD.ZygoteBackend()
40+
ad = AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig())
3841

3942
# example in #69: context is not mutated
4043
@test ad.ruleconfig.context.cache === nothing
@@ -53,6 +56,13 @@ using Zygote
5356
end
5457
@test AD.jacobian(ad, f, [1, 2, 3], 3) ==
5558
([6.0 0.0 0.0; 0.0 6.0 0.0; 0.0 0.0 6.0], [2.0, 4.0, 6.0])
59+
60+
# With `AD.ZygoteBackend`:
61+
ad = AD.ZygoteBackend()
62+
@test AD.derivative(ad, exp, 1.0) === (exp(1.0),)
63+
@test AD.derivative(ad, exp, 1.0) === (exp(1.0),)
64+
@test AD.jacobian(ad, f, [1, 2, 3], 3) ==
65+
([6.0 0.0 0.0; 0.0 6.0 0.0; 0.0 0.0 6.0], [2.0, 4.0, 6.0])
5666
end
5767

5868
# issue #57
@@ -65,5 +75,17 @@ using Zygote
6575

6676
@test_logs Zygote.gradient(myfunc, 1) # nothing is logged
6777
@test_logs AD.derivative(AD.ZygoteBackend(), myfunc, 1) # nothing is logged
78+
@test_logs AD.derivative(
79+
AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()), myfunc, 1
80+
) # nothing is logged
81+
end
82+
83+
# issue #54
84+
@testset "allocations of jacobian" begin
85+
f(x) = x .^ 2
86+
x = rand(100)
87+
ad = AD.ZygoteBackend()
88+
@test AD.jacobian(ad, f, x) == Zygote.jacobian(f, x)
89+
@test @allocated(AD.jacobian(ad, f, x)) == @allocated(Zygote.jacobian(f, x))
6890
end
6991
end

0 commit comments

Comments
 (0)