From 81736ce0cbd44b36cc30c12ee96f9c9737b477bd Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Tue, 9 Aug 2022 17:06:55 -0700 Subject: [PATCH 1/2] project ZeroTangent to natural tangent for some number types --- src/projection.jl | 1 + test/projection.jl | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 8eba26353..fa4c4557b 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -191,6 +191,7 @@ end # understands, including a mix of Zeros & reals. Other cases, we just let through: (project::ProjectTo{<:Number})(dx::Tangent{<:Complex}) = project(Complex(dx.re, dx.im)) (::ProjectTo{<:Number})(dx::Tangent{<:Number}) = dx +(::ProjectTo{T})(::ZeroTangent) where {T<:Real} = zero(T) # Arrays # If we don't have a more specialized `ProjectTo` rule, we just assume that there is diff --git a/test/projection.jl b/test/projection.jl index 3e70772ac..d1e181da6 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -37,10 +37,10 @@ struct NoSuperType end @test ProjectTo(1.0)(2) === 2.0 # Tangents - ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(; re=1, im=NoTangent())) === - 1.0f0 + 0.0f0im - - @test 1.0 === ProjectTo(1.0)(Tangent{ComplexF64}(; re=1, im=NoTangent())) + complex_tangent = Tangent{ComplexF64}(; re=1, im=NoTangent()) + @test ProjectTo(1.0f0 + 2im)(complex_tangent) === 1.0f0 + 0.0f0im + @test ProjectTo(1.0)(complex_tangent) === 1.0 + @test ProjectTo(1.0)(ZeroTangent()) === 0.0 end @testset "Dual" begin # some weird Real subtype that we should basically leave alone From 54fcbed1bcf89e13690b54763dd5ece2a45caf41 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Tue, 9 Aug 2022 18:19:52 -0700 Subject: [PATCH 2/2] Tweak tests to match wrt. natural tangents --- test/projection.jl | 3 +-- test/rules.jl | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/test/projection.jl b/test/projection.jl index d1e181da6..771b971b4 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -212,7 +212,7 @@ struct NoSuperType end @test ProjectTo(I)(123) === NoTangent() @test ProjectTo(2 * I)(I * 3im) === 0.0 * I @test ProjectTo((4 + 5im) * I)(Tangent{typeof(im * I)}(; λ = 6)) === (6.0 + 0.0im) * I - @test ProjectTo(7 * I)(Tangent{typeof(2I)}()) == ZeroTangent() + @test ProjectTo(7 * I)(Tangent{typeof(2I)}()) == 0.0I end @testset "LinearAlgebra: $adj vectors" for adj in [transpose, adjoint] @@ -413,7 +413,6 @@ struct NoSuperType end @test pb(ZeroTangent()) isa AbstractZero # was a method ambiguity! # all projectors preserve Zero, and specific type, via one fallback method: - @test ProjectTo(pi)(ZeroTangent()) === ZeroTangent() @test ProjectTo(pi)(NoTangent()) === NoTangent() pv = ProjectTo(sprand(30, 0.3)) @test pv(ZeroTangent()) === ZeroTangent() diff --git a/test/rules.jl b/test/rules.jl index 54c10b160..a44c18543 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -94,7 +94,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) frx, nice_pushforward = frule((dself, 1), nice, 1) @test nice_pushforward === ZeroTangent() rrx, nice_pullback = rrule(nice, 1) - @test (NoTangent(), ZeroTangent()) === nice_pullback(1) + @test (NoTangent(), 0.0) === nice_pullback(1) # Test that these run. Do not care about numerical correctness. @test frule((nothing, 1.0, 1.0, 1.0), varargs_function, 0.5, 0.5, 0.5) == (1.5, 3.0)