Skip to content

Fix #99 #102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions src/AbstractDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,16 +163,26 @@ function pushforward_function(
xs...,
)
return (ds) -> begin
if ds isa Tuple
if length(xs) != length(ds)
throw(ArgumentError("The input and tangents are not of compatible sizes."))
end
z = _zero.(xs, ds)
else
if length(xs) != 1
throw(ArgumentError("The input and tangents are not of compatible sizes."))
end
z = _zero.(xs, (ds,))
end
return jacobian(lowest(ab), (xds...,) -> begin
if ds isa Tuple
@assert length(xs) == length(ds)
newxs = xs .+ ds .* xds
return f(newxs...)
else
newx = only(xs) + ds * only(xds)
return f(newx)
end
end, _zero.(xs, ds)...)
end, z...)
end
end
function value_and_pushforward_function(
Expand All @@ -184,7 +194,9 @@ function value_and_pushforward_function(
if !(ds isa Tuple)
ds = (ds,)
end
@assert length(ds) == length(xs)
if length(ds) != length(xs)
throw(ArgumentError("The input and tangents are not of compatible sizes."))
end
local value
primalcalled = false
if ab isa AbstractFiniteDifference
Expand Down Expand Up @@ -225,14 +237,27 @@ function pullback_function(ab::AbstractBackend, f, xs...)
return gradient(lowest(ab), (xs...,) -> begin
vs = f(xs...)
if ws isa Tuple
@assert length(vs) == length(ws)
return sum(Base.splat(_dot), zip(ws, vs))
if vs isa Tuple
if length(vs) != length(ws)
throw(ArgumentError("The output and cotangents are not of compatible sizes."))
end
return sum(Base.splat(_dot), zip(ws, vs))
else
if 1 != length(ws)
throw(ArgumentError("The output and cotangents are not of compatible sizes."))
end
return _dot(vs, only(ws))
end
else
if vs isa Tuple
throw(ArgumentError("The output and cotangents are not of compatible sizes."))
end
return _dot(vs, ws)
end
end, xs...)
end
end

function value_and_pullback_function(
ab::AbstractBackend,
f,
Expand Down
30 changes: 30 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,19 @@ function test_jvp(backend; multiple_inputs=true, vaugmented=false, rng=Random.GL
end

valvec1, pf1 = AD.value_and_pushforward_function(backend, x -> fjac(x, yvec), xvec)(v[1])
_valvec1, _pf1 = AD.value_and_pushforward_function(backend, x -> fjac(x, yvec), xvec)((v[1],))
@test valvec1 == _valvec1
@test pf1 == _pf1

@test_throws MethodError AD.value_and_pushforward_function(backend, x -> fjac(x, yvec), xvec)(v[1], nothing) # 1 input, 2 plain tangents
@test_throws ArgumentError AD.value_and_pushforward_function(backend, x -> fjac(x, yvec), xvec)((v[1], nothing)) # 1 input, 2 tuple tangents
@test_throws ArgumentError AD.value_and_pushforward_function(backend, (x, _) -> fjac(x, yvec), xvec, nothing)(v[1]) # 2 inputs, 1 plain tangent
@test_throws ArgumentError AD.value_and_pushforward_function(backend, (x, _) -> fjac(x, yvec), xvec, nothing)((v[1],)) # 2 inputs, 1 tuple tangent

valvec2, pf2 = AD.value_and_pushforward_function(backend, y -> fjac(xvec, y), yvec)(v[2])
_valvec2, _pf2 = AD.value_and_pushforward_function(backend, y -> fjac(xvec, y), yvec)((v[2],))
@test valvec2 == _valvec2
@test pf2 == _pf2

if test_types
@test valvec1 isa Vector{Float64}
Expand All @@ -247,7 +259,17 @@ function test_j′vp(backend; multiple_inputs=true, rng=Random.GLOBAL_RNG, test_
w = rand(rng, length(fjac(xvec, yvec)))
if multiple_inputs
pb1 = AD.pullback_function(backend, fjac, xvec, yvec)(w)
_pb1 = AD.pullback_function(backend, fjac, xvec, yvec)((w,))
@test pb1 == _pb1

@test_throws MethodError AD.pullback_function(backend, fjac, xvec, yvec)(w, nothing) # 1 output, 2 plain cotangents
@test_throws ArgumentError AD.pullback_function(backend, fjac, xvec, yvec)((w, nothing)) # 1 output, 2 tuple cotangents
# TODO: how to test with 2 outputs and 1 cotangent?

valvec, pb2 = AD.value_and_pullback_function(backend, fjac, xvec, yvec)(w)
_valvec, _pb2 = AD.value_and_pullback_function(backend, fjac, xvec, yvec)((w,))
@test valvec == _valvec
@test pb2 == _pb2

if test_types
@test valvec isa Vector{Float64}
Expand All @@ -264,7 +286,15 @@ function test_j′vp(backend; multiple_inputs=true, rng=Random.GLOBAL_RNG, test_
end

valvec1, pb1 = AD.value_and_pullback_function(backend, x -> fjac(x, yvec), xvec)(w)
_valvec1, _pb1 = AD.value_and_pullback_function(backend, x -> fjac(x, yvec), xvec)((w,))
@test valvec1 == _valvec1
@test pb1 == _pb1

valvec2, pb2 = AD.value_and_pullback_function(backend, y -> fjac(xvec, y), yvec)(w)
_valvec2, _pb2 = AD.value_and_pullback_function(backend, y -> fjac(xvec, y), yvec)((w,))
@test valvec2 == _valvec2
@test pb2 == _pb2

if test_types
@test valvec1 isa Vector{Float64}
@test valvec2 isa Vector{Float64}
Expand Down