From 3d0ae093548adfe763806645ff5e5f5875f41796 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Tue, 14 Mar 2023 22:28:50 +0100 Subject: [PATCH 01/10] Add internal function `_reverse` and overloads --- src/lib/array.jl | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 1c6e09916..ef28d145c 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -2,6 +2,8 @@ using Random, FillArrays, AbstractFFTs using FillArrays: AbstractFill, getindex_value using Base.Broadcast: broadcasted, broadcast_shape using Distributed: pmap, AbstractWorkerPool +using LinearAlgebra: Diagonal, Hermitian, LowerTriangular, UpperTriangular +using LinearAlgebra: UnitLowerTriangular, UnitUpperTriangular @adjoint Array(xs::AbstractArray) = Array(xs), ȳ -> (ȳ,) @adjoint Array(xs::Array) = Array(xs), ȳ -> (ȳ,) @@ -165,10 +167,21 @@ end # This is also used by comprehensions, which do guarantee iteration order. # Not done for pmap, presumably because all is lost if you are relying on its order. _tryreverse(m, backs, Δ) = backs, Δ -_tryreverse(m::typeof(map), backs, Δ) = reverse(backs), reverse(Δ) +_tryreverse(m::typeof(map), backs, Δ) = _reverse(backs), _reverse(Δ) _tryreverse(m, x) = x -_tryreverse(m::typeof(map), x) = reverse(x) +_tryreverse(m::typeof(map), x) = _reverse(x) + +# Fallback +_reverse(x) = reverse(x) + +# Known cases in the standard library on which `reverse` errors (issue #355) +_reverse(x::LowerTriangular) = UpperTriangular(reverse(parent(x))) +_reverse(x::UpperTriangular) = LowerTriangular(reverse(parent(x))) +_reverse(x::UnitLowerTriangular) = UnitUpperTriangular(reverse(parent(x))) +_reverse(x::UnitUpperTriangular) = UnitLowerTriangular(reverse(parent(x))) +_reverse(x::Hermitian) = Hermitian(reverse(collect(x))) +_reverse(x::Symmetric) = Symmetric(reverse(collect(x))) # With mismatched lengths, map stops early. With mismatched shapes, it makes a vector. # So we keep axes(x) to restore gradient dx to its full length & correct shape. From 4835c520c1360178505d5c0a43d1e308621f9b76 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Tue, 14 Mar 2023 22:32:58 +0100 Subject: [PATCH 02/10] Add unit tests --- test/lib/array.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/lib/array.jl b/test/lib/array.jl index d02e9f9d3..417223845 100644 --- a/test/lib/array.jl +++ b/test/lib/array.jl @@ -65,3 +65,13 @@ end end @test gradient(f_comprehension, w)[1] == ones(5) end + +@testset "_reverse" begin + m = [1 2 3; 4 5 6; 7 8 9] + @testset for wrapper in [ + Hermitian, Symmetric, LowerDiagonal, UpperDiagonal, + ] + M = wrapper(m) + @test collect(_reverse(M)) == _reverse(collect(M)) + end +end From 30229d288d623c51ce01c2823ee32938ed3f458b Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Tue, 14 Mar 2023 22:36:51 +0100 Subject: [PATCH 03/10] Correct issue number --- src/lib/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index ef28d145c..abcc7fedd 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -175,7 +175,7 @@ _tryreverse(m::typeof(map), x) = _reverse(x) # Fallback _reverse(x) = reverse(x) -# Known cases in the standard library on which `reverse` errors (issue #355) +# Known cases in the standard library on which `reverse` errors (issue #1393) _reverse(x::LowerTriangular) = UpperTriangular(reverse(parent(x))) _reverse(x::UpperTriangular) = LowerTriangular(reverse(parent(x))) _reverse(x::UnitLowerTriangular) = UnitUpperTriangular(reverse(parent(x))) From 0ec6c1ac1dd5695a1a8205eeaef92a6e5ba55c69 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Tue, 14 Mar 2023 22:38:57 +0100 Subject: [PATCH 04/10] Label testset --- test/lib/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/lib/array.jl b/test/lib/array.jl index 417223845..2f1bd9ac8 100644 --- a/test/lib/array.jl +++ b/test/lib/array.jl @@ -68,7 +68,7 @@ end @testset "_reverse" begin m = [1 2 3; 4 5 6; 7 8 9] - @testset for wrapper in [ + @testset "$wrapper" for wrapper in [ Hermitian, Symmetric, LowerDiagonal, UpperDiagonal, ] M = wrapper(m) From c453471c9b3387bb4a376c139190cd4eea5fcd3c Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Tue, 14 Mar 2023 22:44:06 +0100 Subject: [PATCH 05/10] Add missing wrappers --- test/lib/array.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/lib/array.jl b/test/lib/array.jl index 2f1bd9ac8..8390f1b5d 100644 --- a/test/lib/array.jl +++ b/test/lib/array.jl @@ -69,7 +69,8 @@ end @testset "_reverse" begin m = [1 2 3; 4 5 6; 7 8 9] @testset "$wrapper" for wrapper in [ - Hermitian, Symmetric, LowerDiagonal, UpperDiagonal, + Hermitian, Symmetric, LowerDiagonal, UpperDiagonal, + UnitLowerDiagonal, UnitUpperDiagonal, ] M = wrapper(m) @test collect(_reverse(M)) == _reverse(collect(M)) From c0e75c2340fae13172edda1ca3ea4eb2c3ce24af Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Tue, 14 Mar 2023 22:59:30 +0100 Subject: [PATCH 06/10] Avoid `collect` in `_reverse` for `Hermitian` and `Symmetric` Co-authored-by: David Widmann --- src/lib/array.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index abcc7fedd..ef72cf37b 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -180,8 +180,8 @@ _reverse(x::LowerTriangular) = UpperTriangular(reverse(parent(x))) _reverse(x::UpperTriangular) = LowerTriangular(reverse(parent(x))) _reverse(x::UnitLowerTriangular) = UnitUpperTriangular(reverse(parent(x))) _reverse(x::UnitUpperTriangular) = UnitLowerTriangular(reverse(parent(x))) -_reverse(x::Hermitian) = Hermitian(reverse(collect(x))) -_reverse(x::Symmetric) = Symmetric(reverse(collect(x))) +_reverse(x::Hermitian) = Hermitian(_reverse(x.data), x.uplo == 'U' ? :L : :U) +_reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U) # With mismatched lengths, map stops early. With mismatched shapes, it makes a vector. # So we keep axes(x) to restore gradient dx to its full length & correct shape. From 0f8bc3e6d1323ad6f934bbcac1b8198449c14e69 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Tue, 14 Mar 2023 23:35:15 +0100 Subject: [PATCH 07/10] Use `_reverse` instead of `reverse` Co-authored-by: David Widmann --- src/lib/array.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index ef72cf37b..b7c6d45ee 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -176,10 +176,10 @@ _tryreverse(m::typeof(map), x) = _reverse(x) _reverse(x) = reverse(x) # Known cases in the standard library on which `reverse` errors (issue #1393) -_reverse(x::LowerTriangular) = UpperTriangular(reverse(parent(x))) -_reverse(x::UpperTriangular) = LowerTriangular(reverse(parent(x))) -_reverse(x::UnitLowerTriangular) = UnitUpperTriangular(reverse(parent(x))) -_reverse(x::UnitUpperTriangular) = UnitLowerTriangular(reverse(parent(x))) +_reverse(x::LowerTriangular) = UpperTriangular(_reverse(parent(x))) +_reverse(x::UpperTriangular) = LowerTriangular(_reverse(parent(x))) +_reverse(x::UnitLowerTriangular) = UnitUpperTriangular(_reverse(parent(x))) +_reverse(x::UnitUpperTriangular) = UnitLowerTriangular(_reverse(parent(x))) _reverse(x::Hermitian) = Hermitian(_reverse(x.data), x.uplo == 'U' ? :L : :U) _reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U) From 188eb16c21b7c2b834c6006815d9f6b6946b9d0f Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 15 Mar 2023 09:11:36 +0100 Subject: [PATCH 08/10] Fix wrong names :) Co-authored-by: David Widmann --- test/lib/array.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/lib/array.jl b/test/lib/array.jl index 8390f1b5d..d8850edd7 100644 --- a/test/lib/array.jl +++ b/test/lib/array.jl @@ -69,8 +69,8 @@ end @testset "_reverse" begin m = [1 2 3; 4 5 6; 7 8 9] @testset "$wrapper" for wrapper in [ - Hermitian, Symmetric, LowerDiagonal, UpperDiagonal, - UnitLowerDiagonal, UnitUpperDiagonal, + Hermitian, Symmetric, LowerTriangular, UpperTriangular, + UnitLowerTriangular, UnitUpperTriangular, ] M = wrapper(m) @test collect(_reverse(M)) == _reverse(collect(M)) From b1d7d98fd89820803b7aab8996fa05d4236de199 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 15 Mar 2023 09:27:34 +0100 Subject: [PATCH 09/10] Add end user test case --- test/lib/array.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/lib/array.jl b/test/lib/array.jl index d8850edd7..497e2039e 100644 --- a/test/lib/array.jl +++ b/test/lib/array.jl @@ -1,4 +1,6 @@ using ChainRulesTestUtils +using LinearAlgebra: Diagonal, Hermitian, LowerTriangular, UpperTriangular +using LinearAlgebra: UnitLowerTriangular, UnitUpperTriangular using Zygote: ZygoteRuleConfig, _pullback # issue 897 @@ -76,3 +78,21 @@ end @test collect(_reverse(M)) == _reverse(collect(M)) end end + +@testset "rrule for `map`" begin + @testset "MWE from #1393" begin + # https://github.com/FluxML/Zygote.jl/issues/1393#issuecomment-1468496804 + struct Foo1393 x::Float64 end + (f::Foo1393)(x) = f.x * x + x = randn(5, 5) + out, pb = Zygote.pullback(x -> map(Foo1393(5.0), x), x) + @testset "$wrapper" for wrapper in [ + Hermitian, Symmetric, LowerTriangular, UpperTriangular, + UnitLowerTriangular, UnitUpperTriangular, + ] + m = wrapper(rand(5, 5)) + res = only(pb(m)) + @test res == 5m + end + end +end From afdcfdde25bea08848a9012df2515708cde909d6 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 15 Mar 2023 10:20:50 +0100 Subject: [PATCH 10/10] Add `using Zygote: _reverse` --- test/lib/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/lib/array.jl b/test/lib/array.jl index 497e2039e..889301c1e 100644 --- a/test/lib/array.jl +++ b/test/lib/array.jl @@ -1,7 +1,7 @@ using ChainRulesTestUtils using LinearAlgebra: Diagonal, Hermitian, LowerTriangular, UpperTriangular using LinearAlgebra: UnitLowerTriangular, UnitUpperTriangular -using Zygote: ZygoteRuleConfig, _pullback +using Zygote: ZygoteRuleConfig, _pullback, _reverse # issue 897