From 19bed2687dafbdf7254ee35496f2ef537a0e7497 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 1 Mar 2024 13:48:10 +0800 Subject: [PATCH 1/3] Disable + and - rules on mixed numeric types --- src/rulesets/Base/fastmath_able.jl | 50 ++++++++++++++++++++++++----- test/rulesets/Base/fastmath_able.jl | 30 ++++++++++++++--- 2 files changed, 68 insertions(+), 12 deletions(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 8f84cee03..39cb500d8 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -164,19 +164,53 @@ let return (Ω, hypot_pullback) end - @scalar_rule x + y (true, true) - @scalar_rule x - y (true, -1) - @scalar_rule x / y (one(x) / y, -(Ω / y)) - - ## many-arg + - function frule(Δs, ::typeof(+), x::Number, ys::Number...) + ### + ### + + ### + # Same type so must have same tangent type + function frule(Δs, ::typeof(+), x::T, ys::T...) where {T<:Number} +(x, ys...), +(Base.tail(Δs)...) end - function rrule(::typeof(+), x::Number, ys::Number...) + function rrule(::typeof(+), x::T, ys::T...) where {T<:Number} plus_back(dz) = (NoTangent(), dz, map(Returns(dz), ys)...) +(x, ys...), plus_back - end + end + + #Or Tangent type is same as primal type so op must be defined on it too + function frule((_, ẋ, ẏ)::Tuple{<:Any, A, B}, ::typeof(+), x::A, y::B) where {A<:Number, B<:Number} + return +(x, y), +(ẋ, ẏ) + end + + # Both cases (break ambiguity) + function frule((_, ẋ, ẏ)::Tuple{<:Any, T, T}, ::typeof(+), x::T, y::T) where {T<:Number} + return +(x, y), +(ẋ, ẏ) + end + + + ### + ### - + ### + # Same type so must have same tangent type + function rrule(::typeof(-), x::T, y::T) where {T<:Number} + minus_pullback(z̄) = NoTangent(), z̄, -(z̄) + return x - y, minus_pullback + end + frule((_, ẋ, ẏ), ::typeof(-), x::T, y::T) where {T<:Number} = -(x, y), -(ẋ, ẏ) + + #Or Tangent type is same as primal type so op must be defined on it too + function frule((_, ẋ, ẏ)::Tuple{<:Any, A, B}, ::typeof(-), x::A, y::B) where {A<:Number, B<:Number} + return -(x, y), -(ẋ, ẏ) + end + + # Both cases (break ambiguity) + function frule((_, ẋ, ẏ)::Tuple{<:Any, T, T}, ::typeof(-), x::T, y::T) where {T<:Number} + return -(x, y), -(ẋ, ẏ) + end + + + @scalar_rule x / y (one(x) / y, -(Ω / y)) + ## power # literal_pow is in base.jl diff --git a/test/rulesets/Base/fastmath_able.jl b/test/rulesets/Base/fastmath_able.jl index 45ad33cc3..a35b75a69 100644 --- a/test/rulesets/Base/fastmath_able.jl +++ b/test/rulesets/Base/fastmath_able.jl @@ -144,17 +144,17 @@ const FASTABLE_AST = quote @assert T == typeof(f(x, y)) Δz = randn(typeof(f(x, y))) - @test frule((ZeroTangent(), Δx, Δy), f, x, y) isa Tuple{T, T} + @test frule((NoTangent(), Δx, Δy), f, x, y) isa Tuple{T, T} _, ∂x, ∂y = rrule(f, x, y)[2](Δz) @test (∂x, ∂y) isa Tuple{T, T} - if f != hypot + if f ∉ (hypot, +, -) # Issue #233 - @test frule((ZeroTangent(), Δx, Δy), f, x, 2) isa Tuple{T, T} + @test frule((NoTangent(), Δx, Δy), f, x, 2) isa Tuple{T, T} _, ∂x, ∂y = rrule(f, x, 2)[2](Δz) @test (∂x, ∂y) isa Tuple{T, Float64} - @test frule((ZeroTangent(), Δx, Δy), f, 2, y) isa Tuple{T, T} + @test frule((NoTangent(), Δx, Δy), f, 2, y) isa Tuple{T, T} _, ∂x, ∂y = rrule(f, 2, y)[2](Δz) @test (∂x, ∂y) isa Tuple{Float64, T} end @@ -283,6 +283,28 @@ const FASTABLE_AST = quote end end end + + @testset "+,- on weird types" begin + + struct StoreHalfed <: Number + val::Float64 + StoreHalfed(x) = new(x/2) + end + Base.:-(x::StoreHalfed, y::Number) = 2*x.val - y + Base.:+(x::StoreHalfed, y::Number) = 2*x.val + y + + sh1 = StoreHalfed(4.0) + sh2 = StoreHalfed(8.0) + f1 = 40.0 + f2 = 80.0 + + # We have had issues with mixed number types before + # So these should not hit + @test rrule(+, sh1, f1) == nothing + @test rrule(-, sh1, f1) == nothing + @test frule((NoTangent(), Tangent{StoreHalfed}(val=2.0), 20.0),+, sh1, f1) == nothing + @test frule((NoTangent(), Tangent{StoreHalfed}(val=2.0), 20.0),-, sh1, f1) == nothing + end end # Now we generate tests for fast and nonfast versions From f74d9aa303a10ba457dd6e7f51db4a54e0b6ba52 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 1 Mar 2024 14:20:46 +0800 Subject: [PATCH 2/3] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/rulesets/Base/fastmath_able.jl | 22 +++++++++++++++------- test/rulesets/Base/fastmath_able.jl | 15 +++++++-------- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 39cb500d8..fa2a4fb25 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -178,15 +178,18 @@ let end #Or Tangent type is same as primal type so op must be defined on it too - function frule((_, ẋ, ẏ)::Tuple{<:Any, A, B}, ::typeof(+), x::A, y::B) where {A<:Number, B<:Number} + function frule( + (_, ẋ, ẏ)::Tuple{<:Any,A,B}, ::typeof(+), x::A, y::B + ) where {A<:Number,B<:Number} return +(x, y), +(ẋ, ẏ) end # Both cases (break ambiguity) - function frule((_, ẋ, ẏ)::Tuple{<:Any, T, T}, ::typeof(+), x::T, y::T) where {T<:Number} + function frule( + (_, ẋ, ẏ)::Tuple{<:Any,T,T}, ::typeof(+), x::T, y::T + ) where {T<:Number} return +(x, y), +(ẋ, ẏ) end - ### ### - @@ -196,18 +199,23 @@ let minus_pullback(z̄) = NoTangent(), z̄, -(z̄) return x - y, minus_pullback end - frule((_, ẋ, ẏ), ::typeof(-), x::T, y::T) where {T<:Number} = -(x, y), -(ẋ, ẏ) + function frule((_, ẋ, ẏ), ::typeof(-), x::T, y::T) where {T<:Number} + return -(x, y), -(ẋ, ẏ) + end #Or Tangent type is same as primal type so op must be defined on it too - function frule((_, ẋ, ẏ)::Tuple{<:Any, A, B}, ::typeof(-), x::A, y::B) where {A<:Number, B<:Number} + function frule( + (_, ẋ, ẏ)::Tuple{<:Any,A,B}, ::typeof(-), x::A, y::B + ) where {A<:Number,B<:Number} return -(x, y), -(ẋ, ẏ) end # Both cases (break ambiguity) - function frule((_, ẋ, ẏ)::Tuple{<:Any, T, T}, ::typeof(-), x::T, y::T) where {T<:Number} + function frule( + (_, ẋ, ẏ)::Tuple{<:Any,T,T}, ::typeof(-), x::T, y::T + ) where {T<:Number} return -(x, y), -(ẋ, ẏ) end - @scalar_rule x / y (one(x) / y, -(Ω / y)) diff --git a/test/rulesets/Base/fastmath_able.jl b/test/rulesets/Base/fastmath_able.jl index a35b75a69..415d109b8 100644 --- a/test/rulesets/Base/fastmath_able.jl +++ b/test/rulesets/Base/fastmath_able.jl @@ -144,17 +144,17 @@ const FASTABLE_AST = quote @assert T == typeof(f(x, y)) Δz = randn(typeof(f(x, y))) - @test frule((NoTangent(), Δx, Δy), f, x, y) isa Tuple{T, T} + @test frule((NoTangent(), Δx, Δy), f, x, y) isa Tuple{T,T} _, ∂x, ∂y = rrule(f, x, y)[2](Δz) @test (∂x, ∂y) isa Tuple{T, T} if f ∉ (hypot, +, -) # Issue #233 - @test frule((NoTangent(), Δx, Δy), f, x, 2) isa Tuple{T, T} + @test frule((NoTangent(), Δx, Δy), f, x, 2) isa Tuple{T,T} _, ∂x, ∂y = rrule(f, x, 2)[2](Δz) @test (∂x, ∂y) isa Tuple{T, Float64} - @test frule((NoTangent(), Δx, Δy), f, 2, y) isa Tuple{T, T} + @test frule((NoTangent(), Δx, Δy), f, 2, y) isa Tuple{T,T} _, ∂x, ∂y = rrule(f, 2, y)[2](Δz) @test (∂x, ∂y) isa Tuple{Float64, T} end @@ -285,14 +285,13 @@ const FASTABLE_AST = quote end @testset "+,- on weird types" begin - struct StoreHalfed <: Number val::Float64 - StoreHalfed(x) = new(x/2) + StoreHalfed(x) = new(x / 2) end - Base.:-(x::StoreHalfed, y::Number) = 2*x.val - y - Base.:+(x::StoreHalfed, y::Number) = 2*x.val + y - + Base.:-(x::StoreHalfed, y::Number) = 2 * x.val - y + Base.:+(x::StoreHalfed, y::Number) = 2 * x.val + y + sh1 = StoreHalfed(4.0) sh2 = StoreHalfed(8.0) f1 = 40.0 From a70495c295a89c7366b051e6c0a0fdb8a968ffa1 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 1 Mar 2024 14:21:30 +0800 Subject: [PATCH 3/3] autoformat Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/rulesets/Base/fastmath_able.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/rulesets/Base/fastmath_able.jl b/test/rulesets/Base/fastmath_able.jl index 415d109b8..825c579f8 100644 --- a/test/rulesets/Base/fastmath_able.jl +++ b/test/rulesets/Base/fastmath_able.jl @@ -301,8 +301,10 @@ const FASTABLE_AST = quote # So these should not hit @test rrule(+, sh1, f1) == nothing @test rrule(-, sh1, f1) == nothing - @test frule((NoTangent(), Tangent{StoreHalfed}(val=2.0), 20.0),+, sh1, f1) == nothing - @test frule((NoTangent(), Tangent{StoreHalfed}(val=2.0), 20.0),-, sh1, f1) == nothing + @test frule((NoTangent(), Tangent{StoreHalfed}(; val=2.0), 20.0), +, sh1, f1) == + nothing + @test frule((NoTangent(), Tangent{StoreHalfed}(; val=2.0), 20.0), -, sh1, f1) == + nothing end end