Skip to content

Commit e7c8abd

Browse files
authored
Merge pull request #224 from JuliaDiff/ox/mapfix
Fix map in forwards mode
2 parents b3e4ee0 + fb1dd92 commit e7c8abd

File tree

3 files changed

+41
-8
lines changed

3 files changed

+41
-8
lines changed

src/tangent.jl

+15-5
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ end
9494
coeffs::C
9595
TaylorTangent(coeffs) = $(Expr(:new, :(TaylorTangent{typeof(coeffs)}), :coeffs))
9696
end
97+
Base.:(==)(a::TaylorTangent, b::TaylorTangent) = a.coeffs == b.coeffs
98+
Base.hash(tt::TaylorTangent, h::UInt64) = hash(tt.coeffs, h)
9799

98100
"""
99101
struct TaylorTangent{C}
@@ -159,6 +161,9 @@ TangentBundle
159161
TangentBundle{N}(primal::B, tangent::P) where {N, B, P<:AbstractTangentSpace} =
160162
_TangentBundle(Val{N}(), primal, tangent)
161163

164+
Base.hash(tb::TangentBundle, h::UInt64) = hash(tb.primal, h)
165+
Base.:(==)(a::TangentBundle, b::TangentBundle) = (a.primal == b.primal) && (a.tangent == b.tangent)
166+
162167
const ExplicitTangentBundle{N, B, P} = TangentBundle{N, B, ExplicitTangent{P}}
163168

164169
check_tangent_invariant(lp, N) = @assert lp == 2^N - 1
@@ -201,20 +206,25 @@ end
201206

202207
const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}}
203208

209+
210+
function TaylorBundle{N, B, P}(primal::B, coeffs::P) where {N, B, P}
211+
check_taylor_invariants(coeffs, primal, N)
212+
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
213+
end
204214
function TaylorBundle{N, B}(primal::B, coeffs) where {N, B}
205215
check_taylor_invariants(coeffs, primal, N)
206216
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
207217
end
218+
function TaylorBundle{N}(primal, coeffs) where {N}
219+
check_taylor_invariants(coeffs, primal, N)
220+
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
221+
end
208222

209223
function check_taylor_invariants(coeffs, primal, N)
210224
@assert length(coeffs) == N
211-
212225
end
213226
@ChainRulesCore.non_differentiable check_taylor_invariants(coeffs, primal, N)
214227

215-
function TaylorBundle{N}(primal, coeffs) where {N}
216-
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
217-
end
218228

219229
function Base.show(io::IO, x::TaylorBundle{1})
220230
print(io, x.primal)
@@ -350,7 +360,7 @@ function ChainRulesCore.rrule(::typeof(unbundle), atb::AbstractTangentBundle)
350360
unbundle(atb), Δ->throw(Δ)
351361
end
352362

353-
function StructArrays.createinstance(T::Type{<:ZeroBundle}, args...)
363+
function StructArrays.createinstance(T::Type{<:UniformBundle}, args...)
354364
T(args[1], args[2])
355365
end
356366

test/forward.jl

+21-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module forward_tests
22
using Diffractor
3-
using Diffractor: TaylorBundle, ZeroBundle
3+
using Diffractor: TaylorBundle, ZeroBundle, ∂☆
44
using ChainRules
55
using ChainRulesCore
66
using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
@@ -61,7 +61,7 @@ end
6161
end
6262

6363
# Special case if there is no derivative information at all:
64-
@test (Diffractor.∂☆{1}())(ZeroBundle{1}(foo), ZeroBundle{1}(2.0), ZeroBundle{1}(3.0)) == ZeroBundle{1}(5.0)
64+
@test ∂☆{1}()(ZeroBundle{1}(foo), ZeroBundle{1}(2.0), ZeroBundle{1}(3.0)) == ZeroBundle{1}(5.0)
6565
@test frule_calls[] == 0
6666
@test primal_calls[] == 1
6767
end
@@ -88,6 +88,24 @@ end
8888
end
8989

9090

91+
@testset "map" begin
92+
@test ==(
93+
∂☆{1}()(ZeroBundle{1}(xs->(map(x->2*x, xs))), TaylorBundle{1}([1.0, 2.0], ([10.0, 100.0],))),
94+
TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))
95+
)
96+
97+
98+
# map over all closure, wrt the closed variable
99+
mulby(x) = y->x*y
100+
🐇 = ∂☆{1}()(
101+
ZeroBundle{1}(x->(map(mulby(x), [2.0, 4.0]))),
102+
TaylorBundle{1}(2.0, (10.0,))
103+
)
104+
@test 🐇 == TaylorBundle{1}([4.0, 8.0], ([20.0, 40.0],))
105+
106+
end
107+
108+
91109
@testset "structs" begin
92110
struct IDemo
93111
x::Float64
@@ -166,4 +184,4 @@ end
166184
)
167185
end
168186

169-
end
187+
end # module

test/tangent.jl

+5
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ end
4646
end
4747
end
4848

49+
@testset "== and hash" begin
50+
@test TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],)) == TaylorBundle{1}([2.0, 4.0], ([20.0, 200.0],))
51+
@test hash(TaylorBundle{1}(0.0, (0.0,))) == hash(0)
52+
end
53+
4954
@testset "truncate" begin
5055
tt = TaylorTangent((1.0,2.0,3.0,4.0,5.0,6.0,7.0))
5156
@test truncate(tt, Val(2)) == TaylorTangent((1.0,2.0))

0 commit comments

Comments
 (0)