diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index c1624046..9f84bc88 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -225,6 +225,25 @@ function (::∂⃖{N})(f::T, args...) where {T, N} end end end +function (::∂⃖{N})(::typeof(Core.kwcall), kwargs, f::T, args...) where {T, N} + if N == 1 + # Base case (inlined to avoid ambiguities with manually specified + # higher order rules) + z = rrule(DiffractorRuleConfig(), KwFunc(f), kwargs, f, args...) + if z === nothing + return ∂⃖recurse{1}()(f, args..., kwargs...) + end + return z + else + ∂⃖p = ∂⃖{N-1}() + @destruct z, z̄ = ∂⃖p(rrule, f, args...; kwargs...) + if z === nothing + return ∂⃖recurse{N}()(f, args...; kwargs...) + else + return ∂⃖rrule{N}()(z, z̄) + end + end +end function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) where {T} Tuple{Any, Any}(∂⃖{1}()(f, args...)) @@ -244,6 +263,10 @@ struct KwFunc{T,S} end (kw::KwFunc)(args...) = kw.kwf(args...) +function ChainRulesCore.rrule(::typeof(Core.kwcall), kwargs, f, args...) + rrule(KwFunc(f), kwargs, f, args...) +end + function ChainRulesCore.rrule(::typeof(Core.kwfunc), f) KwFunc(f), Δ->(NoTangent(), Δ) end diff --git a/test/gradcheck.jl b/test/gradcheck.jl index d003c82d..f3bb8e9a 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -99,10 +99,11 @@ end @test gradcheck(x -> sum(i->x[i], 1:length(x)), randn(10)) # issue #231 @test gradcheck(x -> sum((i->x[i]).(1:length(x))), randn(10)) @test gradcheck(X -> sum(x -> x^2, X), randn(10)) + @test jacobicheck(x -> sum(x, dims = (2, 3)), (3,4,5)) + @test jacobicheck(x -> sum(abs2, x; dims=1), randn(4, 3, 2)) # MethodError: no method matching copy(::Nothing) - @test_broken jacobicheck(x -> sum(x, dims = (2, 3)), (3,4,5)) - @test_broken jacobicheck(x -> sum(abs2, x; dims=1), randn(4, 3, 2)) + # TODO: interesting that this is the only one that is not fixed @test_broken gradcheck(X -> sum(sum(x -> x^2, X; dims=1)), randn(10)) # issue #681 # Non-differentiable sum of booleans @@ -119,23 +120,15 @@ end @test gradcheck(x -> prod(x), (3,4)) @test gradient(x -> prod(x), (1,2,3))[1] == (6,3,2) - - # MethodError: no method matching copy(::Nothing) - @test_broken jacobicheck(x -> prod(x, dims = (2, 3)), (3,4,5)) + @test jacobicheck(x -> prod(x, dims = (2, 3)), (3,4,5)) end @testset "cumsum" begin @test jacobicheck(x -> cumsum(x), (4,)) - - # TypeError: in typeassert, expected Int64, got a value of type Nothing - @test_broken jacobicheck(x -> cumsum(x, dims=2), (3,4,5)) - @test_broken jacobicheck(x -> cumsum(x, dims=3), (3,4)) # trivial - - # MethodError: no method matching copy(::Nothing) - @test_broken jacobicheck(x -> cumsum(x, dims=1), (3,)) - - # Rewrite reached intrinsic function bitcast. Missing rule? - @test_broken jacobicheck(x -> cumsum(x, dims=3), (5,)) # trivial + @test jacobicheck(x -> cumsum(x, dims=2), (3,4,5)) + @test jacobicheck(x -> cumsum(x, dims=3), (3,4)) # trivial + @test jacobicheck(x -> cumsum(x, dims=1), (3,)) + @test jacobicheck(x -> cumsum(x, dims=3), (5,)) # trivial end @testset "getindex" begin @@ -221,8 +214,7 @@ end @test jacobicheck(x -> reverse(x), rand(17)) @test jacobicheck(x -> reverse(x, 8), rand(17)) @test jacobicheck(x -> reverse(x, 8, 13), rand(17)) - # Rewrite reached intrinsic function bitcast. Missing rule? - @test_broken jacobicheck(x -> reverse(x, dims=2), rand(17, 42)) + @test jacobicheck(x -> reverse(x, dims=2), rand(17, 42)) end @testset "permutedims" begin @@ -237,11 +229,9 @@ end end @testset "repeat" begin - # MethodError: no method matching copy(::Nothing) - @test_broken jacobicheck(x -> repeat(x; inner=2), rand(5)) - @test_broken jacobicheck(x -> repeat(x; inner=2, outer=3), rand(5)) - @test_broken jacobicheck(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3)) - + @test jacobicheck(x -> repeat(x; inner=2), rand(5)) + @test jacobicheck(x -> repeat(x; inner=2, outer=3), rand(5)) + @test jacobicheck(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3)) @test jacobicheck(x -> repeat(x, 3), rand(5)) @test jacobicheck(x -> repeat(x, 2, 3), rand(5)) @test jacobicheck(x -> repeat(x, 5), rand(5,7)) @@ -453,11 +443,10 @@ end @test gradient(v->sort(v)[i], [1.,2,3])[1][correct[2][i]] == 1 end for i = 1:3 - # Rewrite reached intrinsic function bitcast. Missing rule? - @test_broken gradient(v->sort(v,by=x->x%10)[i], [11,2,99])[1][correct[3][i]] == 1 - @test_broken gradient(v->sort(v,by=x->x%10)[i], [2,11,99])[1][correct[4][i]] == 1 - @test_broken gradient(v->sort(v,rev=true)[i], [3.,1,2])[1][correct[5][i]] == 1 - @test_broken gradient(v->sort(v,rev=true)[i], [1.,2,3])[1][correct[6][i]] == 1 + @test gradient(v->sort(v,by=x->x%10)[i], [11,2,99])[1][correct[3][i]] == 1 + @test gradient(v->sort(v,by=x->x%10)[i], [2,11,99])[1][correct[4][i]] == 1 + @test gradient(v->sort(v,rev=true)[i], [3.,1,2])[1][correct[5][i]] == 1 + @test gradient(v->sort(v,rev=true)[i], [1.,2,3])[1][correct[6][i]] == 1 end end @@ -473,27 +462,21 @@ end @testset "maximum" begin @test jacobicheck(maximum, rand(2, 3)) - - # MethodError: no method matching copy(::Nothing) - @test_broken jacobicheck(x -> maximum(x, dims=1), rand(2, 3)) - @test_broken jacobicheck(x -> maximum(x, dims=3), rand(2, 3, 4)) - @test_broken jacobicheck(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4)) - + @test jacobicheck(x -> maximum(x, dims=1), rand(2, 3)) + @test jacobicheck(x -> maximum(x, dims=3), rand(2, 3, 4)) + @test jacobicheck(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4)) @test gradient(x -> 1 / maximum(x), [1., 2, 3])[1] == [0, 0, -1/9] end @testset "minimum" begin @test jacobicheck(minimum, rand(2, 3)) - - # MethodError: no method matching copy(::Nothing) - @test_broken jacobicheck(x -> minimum(x, dims=1), rand(2, 3)) - @test_broken jacobicheck(x -> minimum(x, dims=2), rand(2, 3)) + @test jacobicheck(x -> minimum(x, dims=1), rand(2, 3)) + @test jacobicheck(x -> minimum(x, dims=2), rand(2, 3)) end @testset "dropdims" begin # https://github.com/JuliaDiff/Diffractor.jl/issues/72 - # TypeError: in typeassert, expected Int64, got a value of type Nothing - @test_broken jacobicheck(x -> dropdims(x, dims = 3), rand(2, 2, 1, 2)) - @test_broken jacobicheck(x -> dropdims(x, dims = (2, 3)), rand(2, 1, 1, 3)) + @test jacobicheck(x -> dropdims(x, dims = 3), rand(2, 2, 1, 2)) + @test jacobicheck(x -> dropdims(x, dims = (2, 3)), rand(2, 1, 1, 3)) end @testset "vcat" begin @@ -544,20 +527,19 @@ end end @testset "cat(...; dims = $dim)" for dim in 1:3 - # Rewrite reached intrinsic function bitcast. Missing rule? catdim = (x...) -> cat(x..., dims = dim) - @test_broken jacobicheck(catdim, rand(4,1)) - @test_broken jacobicheck(catdim, rand(5), rand(5,1)) - @test_broken jacobicheck(catdim, rand(2,5), rand(2,5), rand(2,5)) + @test jacobicheck(catdim, rand(4,1)) + @test jacobicheck(catdim, rand(5), rand(5,1)) + @test jacobicheck(catdim, rand(2,5), rand(2,5), rand(2,5)) catdimval = (x...) -> cat(x...; dims = Val(dim)) - @test_broken jacobicheck(catdimval, rand(5), rand(5)) - @test_broken jacobicheck(catdimval, rand(2,5), rand(2,5,1)) + @test jacobicheck(catdimval, rand(5), rand(5)) + @test jacobicheck(catdimval, rand(2,5), rand(2,5,1)) # one empty dim == 1 || continue - @test_broken jacobicheck(catdim, rand(0,5,3), rand(2,5,3)) + @test jacobicheck(catdim, rand(0,5,3), rand(2,5,3)) end @testset "one(s) and zero(s)" begin @@ -574,7 +556,7 @@ end end @testset "broadcast" begin - @test gradient(x -> sum(sin.(x)), Diagonal([0,pi/2,pi]))[1] ≈ [1 0 0; 0 0 0; 0 0 -1] + @test_broken gradient(x -> sum(sin.(x)), Diagonal([0,pi/2,pi]))[1] ≈ [1 0 0; 0 0 0; 0 0 -1] # mixing arrays & Ref(array) a = rand(3) @@ -586,8 +568,7 @@ end # tests for https://github.com/FluxML/Zygote.jl/issues/724 x1 = rand(3, 3) @test gradient(x -> sum(x .== 0.5), x1) |> only |> isZero - # MethodError: no method matching copy(::Nothing) - @test_broken gradient(x -> sum(x .* (x .== maximum(x, dims=1))), x1)[1] == (x1 .== maximum(x1, dims=1)) + @test gradient(x -> sum(x .* (x .== maximum(x, dims=1))), x1)[1] == (x1 .== maximum(x1, dims=1)) # tests for un-broadcasting *, / via scalar rules @test all(gradient((x,y) -> sum(x .* y), [1,2], 5) .≈ ([5, 5], 3)) @@ -620,7 +601,7 @@ end @test_broken jacobicheck(+, A, B, A) @test jacobicheck(-, A) # in typeassert, expected Int64, got a value of type Nothing - @test_broken jacobicheck(-, A, B) + @test jacobicheck(-, A, B) end end