From 602c23b5ab17799f58d3aedbca7b34bdf5c551aa Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Mon, 7 Aug 2023 05:36:46 +1000 Subject: [PATCH 1/3] fix #99 --- src/AbstractDifferentiation.jl | 7 +++++-- test/test_utils.jl | 21 +++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index 144f382..5b01284 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -163,6 +163,7 @@ function pushforward_function( xs..., ) return (ds) -> begin + z = (ds isa Tuple ? _zero.(xs, ds) : _zero.(xs, (ds,))) return jacobian(lowest(ab), (xds...,) -> begin if ds isa Tuple @assert length(xs) == length(ds) @@ -172,7 +173,7 @@ function pushforward_function( newx = only(xs) + ds * only(xds) return f(newx) end - end, _zero.(xs, ds)...) + end, z...) end end function value_and_pushforward_function( @@ -224,9 +225,11 @@ function pullback_function(ab::AbstractBackend, f, xs...) return (ws) -> begin return gradient(lowest(ab), (xs...,) -> begin vs = f(xs...) - if ws isa Tuple + if ws isa Tuple && length(ws) > 1 @assert length(vs) == length(ws) return sum(Base.splat(_dot), zip(ws, vs)) + elseif ws isa Tuple && length(ws) == 1 + return _dot(vs, only(ws)) else return _dot(vs, ws) end diff --git a/test/test_utils.jl b/test/test_utils.jl index 3711fcd..54b2722 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -229,7 +229,14 @@ function test_jvp(backend; multiple_inputs=true, vaugmented=false, rng=Random.GL end valvec1, pf1 = AD.value_and_pushforward_function(backend, x -> fjac(x, yvec), xvec)(v[1]) + _valvec1, _pf1 = AD.value_and_pushforward_function(backend, x -> fjac(x, yvec), xvec)((v[1],)) + @test valvec1 == _valvec1 + @test pf1 == _pf1 + valvec2, pf2 = AD.value_and_pushforward_function(backend, y -> fjac(xvec, y), yvec)(v[2]) + _valvec2, _pf2 = AD.value_and_pushforward_function(backend, y -> fjac(xvec, y), yvec)((v[2],)) + @test valvec2 == _valvec2 + @test pf2 == _pf2 if test_types @test valvec1 isa Vector{Float64} @@ -247,7 +254,13 @@ function test_j′vp(backend; multiple_inputs=true, rng=Random.GLOBAL_RNG, test_ w = rand(rng, length(fjac(xvec, yvec))) if multiple_inputs pb1 = AD.pullback_function(backend, fjac, xvec, yvec)(w) + _pb1 = AD.pullback_function(backend, fjac, xvec, yvec)((w,)) + @test pb1 == _pb1 + valvec, pb2 = AD.value_and_pullback_function(backend, fjac, xvec, yvec)(w) + _valvec, _pb2 = AD.value_and_pullback_function(backend, fjac, xvec, yvec)((w,)) + @test valvec == _valvec + @test pb2 == _pb2 if test_types @test valvec isa Vector{Float64} @@ -264,7 +277,15 @@ function test_j′vp(backend; multiple_inputs=true, rng=Random.GLOBAL_RNG, test_ end valvec1, pb1 = AD.value_and_pullback_function(backend, x -> fjac(x, yvec), xvec)(w) + _valvec1, _pb1 = AD.value_and_pullback_function(backend, x -> fjac(x, yvec), xvec)((w,)) + @test valvec1 == _valvec1 + @test pb1 == _pb1 + valvec2, pb2 = AD.value_and_pullback_function(backend, y -> fjac(xvec, y), yvec)(w) + _valvec2, _pb2 = AD.value_and_pullback_function(backend, y -> fjac(xvec, y), yvec)((w,)) + @test valvec2 == _valvec2 + @test pb2 == _pb2 + if test_types @test valvec1 isa Vector{Float64} @test valvec2 isa Vector{Float64} From 4ec54502fd8d070867b1e07af3f6e78eead6989f Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Mon, 7 Aug 2023 05:44:24 +1000 Subject: [PATCH 2/3] improve fix --- src/AbstractDifferentiation.jl | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index 5b01284..583b89a 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -163,10 +163,17 @@ function pushforward_function( xs..., ) return (ds) -> begin - z = (ds isa Tuple ? _zero.(xs, ds) : _zero.(xs, (ds,))) + if ds isa Tuple + @assert length(xs) == length(ds) + z = _zero.(xs, ds) + elseif length(xs) == 1 + z = _zero.(xs, (ds,)) + else + z = 0 + throw(ArgumentError("The input and tangents are not of compatible sizes.")) + end return jacobian(lowest(ab), (xds...,) -> begin if ds isa Tuple - @assert length(xs) == length(ds) newxs = xs .+ ds .* xds return f(newxs...) else @@ -225,7 +232,7 @@ function pullback_function(ab::AbstractBackend, f, xs...) return (ws) -> begin return gradient(lowest(ab), (xs...,) -> begin vs = f(xs...) - if ws isa Tuple && length(ws) > 1 + if ws isa Tuple && vs isa Tuple @assert length(vs) == length(ws) return sum(Base.splat(_dot), zip(ws, vs)) elseif ws isa Tuple && length(ws) == 1 From 7f53a38c880cb90dc0bf67dcf494b75aa96df096 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 7 Aug 2023 00:27:13 +0200 Subject: [PATCH 3/3] Add error tests for pushforward and pullback, failing for finitedifferences --- src/AbstractDifferentiation.jl | 37 ++++++++++++++++++++++++---------- test/test_utils.jl | 9 +++++++++ 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index 583b89a..5cf15d2 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -164,13 +164,15 @@ function pushforward_function( ) return (ds) -> begin if ds isa Tuple - @assert length(xs) == length(ds) + if length(xs) != length(ds) + throw(ArgumentError("The input and tangents are not of compatible sizes.")) + end z = _zero.(xs, ds) - elseif length(xs) == 1 - z = _zero.(xs, (ds,)) else - z = 0 - throw(ArgumentError("The input and tangents are not of compatible sizes.")) + if length(xs) != 1 + throw(ArgumentError("The input and tangents are not of compatible sizes.")) + end + z = _zero.(xs, (ds,)) end return jacobian(lowest(ab), (xds...,) -> begin if ds isa Tuple @@ -192,7 +194,9 @@ function value_and_pushforward_function( if !(ds isa Tuple) ds = (ds,) end - @assert length(ds) == length(xs) + if length(ds) != length(xs) + throw(ArgumentError("The input and tangents are not of compatible sizes.")) + end local value primalcalled = false if ab isa AbstractFiniteDifference @@ -232,17 +236,28 @@ function pullback_function(ab::AbstractBackend, f, xs...) return (ws) -> begin return gradient(lowest(ab), (xs...,) -> begin vs = f(xs...) - if ws isa Tuple && vs isa Tuple - @assert length(vs) == length(ws) - return sum(Base.splat(_dot), zip(ws, vs)) - elseif ws isa Tuple && length(ws) == 1 - return _dot(vs, only(ws)) + if ws isa Tuple + if vs isa Tuple + if length(vs) != length(ws) + throw(ArgumentError("The output and cotangents are not of compatible sizes.")) + end + return sum(Base.splat(_dot), zip(ws, vs)) + else + if 1 != length(ws) + throw(ArgumentError("The output and cotangents are not of compatible sizes.")) + end + return _dot(vs, only(ws)) + end else + if vs isa Tuple + throw(ArgumentError("The output and cotangents are not of compatible sizes.")) + end return _dot(vs, ws) end end, xs...) end end + function value_and_pullback_function( ab::AbstractBackend, f, diff --git a/test/test_utils.jl b/test/test_utils.jl index 54b2722..01d10e8 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -233,6 +233,11 @@ function test_jvp(backend; multiple_inputs=true, vaugmented=false, rng=Random.GL @test valvec1 == _valvec1 @test pf1 == _pf1 + @test_throws MethodError AD.value_and_pushforward_function(backend, x -> fjac(x, yvec), xvec)(v[1], nothing) # 1 input, 2 plain tangents + @test_throws ArgumentError AD.value_and_pushforward_function(backend, x -> fjac(x, yvec), xvec)((v[1], nothing)) # 1 input, 2 tuple tangents + @test_throws ArgumentError AD.value_and_pushforward_function(backend, (x, _) -> fjac(x, yvec), xvec, nothing)(v[1]) # 2 inputs, 1 plain tangent + @test_throws ArgumentError AD.value_and_pushforward_function(backend, (x, _) -> fjac(x, yvec), xvec, nothing)((v[1],)) # 2 inputs, 1 tuple tangent + valvec2, pf2 = AD.value_and_pushforward_function(backend, y -> fjac(xvec, y), yvec)(v[2]) _valvec2, _pf2 = AD.value_and_pushforward_function(backend, y -> fjac(xvec, y), yvec)((v[2],)) @test valvec2 == _valvec2 @@ -257,6 +262,10 @@ function test_j′vp(backend; multiple_inputs=true, rng=Random.GLOBAL_RNG, test_ _pb1 = AD.pullback_function(backend, fjac, xvec, yvec)((w,)) @test pb1 == _pb1 + @test_throws MethodError AD.pullback_function(backend, fjac, xvec, yvec)(w, nothing) # 1 output, 2 plain cotangents + @test_throws ArgumentError AD.pullback_function(backend, fjac, xvec, yvec)((w, nothing)) # 1 output, 2 tuple cotangents + # TODO: how to test with 2 outputs and 1 cotangent? + valvec, pb2 = AD.value_and_pullback_function(backend, fjac, xvec, yvec)(w) _valvec, _pb2 = AD.value_and_pullback_function(backend, fjac, xvec, yvec)((w,)) @test valvec == _valvec