diff --git a/Manifest.toml b/Manifest.toml index cda0ac4d..eb00e5c8 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -27,9 +27,9 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] -git-tree-sha1 = "c46adabdd0348f0ee8de91142cfc4a72a613ac0a" +git-tree-sha1 = "fdde4d8a31cf82b1d136cf6cb53924e8744a832b" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.46.1" +version = "1.47.0" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] @@ -265,9 +265,9 @@ version = "1.10.0" [[deps.StaticArrays]] deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] -git-tree-sha1 = "129703d62117c374c4f2db6d13a027741c46eafd" +git-tree-sha1 = "cee507162ecbb677450f20058ca83bd559b6b752" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.5.13" +version = "1.5.14" [[deps.StaticArraysCore]] git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a" diff --git a/src/extra_rules.jl b/src/extra_rules.jl index a2aa7dc5..0b0e8b51 100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -79,7 +79,7 @@ function (::∂⃖{N})(f::typeof(*), args...) where {N} end return z else - ∂⃖p = ∂⃖{minus1(N)}() + ∂⃖p = ∂⃖{N-1}() @destruct z, z̄ = ∂⃖p(rrule_times, f, args...) if z === nothing return ∂⃖recurse{N}()(f, args...) @@ -130,15 +130,15 @@ end struct NonDiffEven{N, O, P}; end struct NonDiffOdd{N, O, P}; end -(::NonDiffOdd{N, O, P})(Δ) where {N, O, P} = (ntuple(_->ZeroTangent(), N), NonDiffEven{N, plus1(O), P}()) -(::NonDiffEven{N, O, P})(Δ...) where {N, O, P} = (ZeroTangent(), NonDiffOdd{N, plus1(O), P}()) +(::NonDiffOdd{N, O, P})(Δ) where {N, O, P} = (ntuple(_->ZeroTangent(), N), NonDiffEven{N, O+1, P}()) +(::NonDiffEven{N, O, P})(Δ...) where {N, O, P} = (ZeroTangent(), NonDiffOdd{N, O+1, P}()) (::NonDiffOdd{N, O, O})(Δ) where {N, O} = ntuple(_->ZeroTangent(), N) # This should not happen (::NonDiffEven{N, O, O})(Δ...) where {N, O} = error() -@Base.pure function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(Core.apply_type), head, args...) - Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, 1}() +@Base.assume_effects :total function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(Core.apply_type), head, args...) + Core.apply_type(head, args...), NonDiffOdd{length(args)+2, 1, 1}() end function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(Core.tuple), args...) diff --git a/src/interface.jl b/src/interface.jl index f7fae6aa..2aa4b855 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -64,7 +64,7 @@ dx(x) = error("Cotangent space not defined for `$(typeof(x))`. Try a real-valued For `x` in a one dimensional manifold, map x to the trivial, unital, 1st order tangent bundle. It should hold that `∀x ⟨∂x(x), dx(x)⟩ = 1` """ -∂x(x::Real) = ExplicitTangentBundle{1}(x, (one(x),)) +∂x(x::Real) = TaylorBundle{1}(x, (one(x),)) ∂x(x) = error("Tangent space not defined for `$(typeof(x)).") struct ∂xⁿ{N}; end @@ -143,11 +143,9 @@ Base.show(io::IO, f::PrimeDerivativeBack{N}) where {N} = print(io, f.f, "'"^N) # This improves performance for nested derivatives by short cutting some # recursion into the PrimeDerivative constructor -@Base.pure minus1(N) = N - 1 -@Base.pure plus1(N) = N + 1 -lower_pd(f::PrimeDerivativeBack{N,T}) where {N,T} = PrimeDerivativeBack{minus1(N),T}(getfield(f, :f)) +lower_pd(f::PrimeDerivativeBack{N,T}) where {N,T} = PrimeDerivativeBack{N-1,T}(getfield(f, :f)) lower_pd(f::PrimeDerivativeBack{1}) = getfield(f, :f) -raise_pd(f::PrimeDerivativeBack{N,T}) where {N,T} = PrimeDerivativeBack{plus1(N),T}(getfield(f, :f)) +raise_pd(f::PrimeDerivativeBack{N,T}) where {N,T} = PrimeDerivativeBack{N+1,T}(getfield(f, :f)) ChainRulesCore.rrule(::typeof(lower_pd), f) = lower_pd(f), Δ->(ZeroTangent(), Δ) ChainRulesCore.rrule(::typeof(raise_pd), f) = raise_pd(f), Δ->(ZeroTangent(), Δ) @@ -170,8 +168,8 @@ end PrimeDerivativeFwd(f) = PrimeDerivativeFwd{1, typeof(f)}(f) PrimeDerivativeFwd(f::PrimeDerivativeFwd{N, T}) where {N, T} = raise_pd(f) -lower_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = (error(); PrimeDerivativeFwd{minus1(N),T}(getfield(f, :f))) -raise_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = PrimeDerivativeFwd{plus1(N),T}(getfield(f, :f)) +lower_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = (error(); PrimeDerivativeFwd{N-1,T}(getfield(f, :f))) +raise_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = PrimeDerivativeFwd{N+1,T}(getfield(f, :f)) (f::PrimeDerivativeFwd{0})(x) = getfield(f, :f)(x) diff --git a/src/jet.jl b/src/jet.jl index 5584e32d..ab46653f 100644 --- a/src/jet.jl +++ b/src/jet.jl @@ -1,5 +1,5 @@ """ - struct Jet{T, N} + struct Jet{S, T, N} Represents the truncated (N-1)-th order Taylor series @@ -15,8 +15,8 @@ For a jet `j`, several operations are supported: derivatives. Mathematically this corresponds to an infinitessimal ball around `a`. """ -struct Jet{T, N} - a::T +struct Jet{S, T, N} + a::S f₀::T fₙ::NTuple{N, T} end @@ -25,13 +25,13 @@ function ChainRulesCore.rrule(::typeof(Base.getproperty), j::Jet, s) error("Raw getproperty not allowed in AD code") end -function Base.:+(j1::Jet{T, N}, j2::Jet{T, N}) where {T, N} +function Base.:+(j1::Jet{S, T, N}, j2::Jet{S, T, N}) where {S, T, N} @assert j1.a === j2.a - Jet{T, N}(j1.a, j1.f₀ + j2.f₀, map(+, j1.fₙ, j2.fₙ)) + Jet{S, T, N}(j1.a, j1.f₀ + j2.f₀, map(+, j1.fₙ, j2.fₙ)) end -function Base.:+(j::Jet{T, N}, x::T) where {T, N} - Jet{T, N}(j.a, j.f₀+x, j.fₙ) +function Base.:+(j::Jet{S, T, N}, x::T) where {S, T, N} + Jet{S, T, N}(j.a, j.f₀+x, j.fₙ) end struct One; end @@ -44,9 +44,9 @@ function ChainRulesCore.rrule(::typeof(+), j::Jet, x::One) j + x, Δ->(NoTangent(), One(), ZeroTangent()) end -function Base.zero(j::Jet{T, N}) where {T, N} +function Base.zero(j::Jet{S, T, N}) where {S, T, N} let z = zero(j[0]) - Jet{T, N}(j.a, z, + Jet{S, T, N}(j.a, z, ntuple(_->z, N)) end end @@ -54,18 +54,18 @@ function ChainRulesCore.rrule(::typeof(Base.zero), j::Jet) zero(j), Δ->(NoTangent(), ZeroTangent()) end -function Base.getindex(j::Jet{T, N}, i::Integer) where {T, N} +function Base.getindex(j::Jet{S, T, N}, i::Integer) where {S, T, N} (0 <= i <= N) || throw(BoundsError(j, i)) i == 0 && return j.f₀ return j.fₙ[i] end -function deriv(j::Jet{T, N}) where {T, N} - Jet{T, N-1}(j.a, j.fₙ[1], Base.tail(j.fₙ)) +function deriv(j::Jet{S, T, N}) where {S, T, N} + Jet{S, T, N-1}(j.a, j.fₙ[1], Base.tail(j.fₙ)) end -function integrate(j::Jet{T, N}) where {T, N} - Jet{T, N+1}(j.a, zero(j.f₀), tuple(j.f₀, j.fₙ...)) +function integrate(j::Jet{S, T, N}) where {S, T, N} + Jet{S, T, N+1}(j.a, zero(j.f₀), tuple(j.f₀, j.fₙ...)) end deriv(::NoTangent) = NoTangent() @@ -187,9 +187,8 @@ function (∂⃖ₙ::∂⃖{N})(::typeof(map), f, a::Array) where {N} ∂f = ∂☆{N}()(ZeroBundle{N}(f), TaylorBundle{N}(x, (one(x), (zero(x) for i = 1:(N-1))...,))) - @assert isa(∂f, TaylorBundle) || isa(∂f, ExplicitTangentBundle{1}) - Jet{typeof(x), N}(x, ∂f.primal, - isa(∂f, ExplicitTangentBundle) ? ∂f.tangent.partials : ∂f.tangent.coeffs) + @assert isa(∂f, TaylorBundle) + Jet{typeof(x), typeof(x), N}(x, ∂f.primal, ∂f.tangent.coeffs) end ∂⃖ₙ(mapev, js, a) end @@ -239,7 +238,7 @@ expressions for the t′ᵢ that are hopefully easier on the compiler. end...) end -@generated function (j::Jet{T, N} where T)(x::TaylorBundle{M}) where {N, M} +@generated function (j::Jet{S, T, N} where {S, T})(x::TaylorBundle{M}) where {N, M} O = min(M,N) quote domain_check(j, x.primal) @@ -248,13 +247,3 @@ end ($((:(jet_taylor_ev(Val{$i}(), coeffs, j)) for i = 1:O)...),)) end end - -function (j::Jet{T, 1} where T)(x::ExplicitTangentBundle{1}) - domain_check(j, x.primal) - coeffs = x.tangent.partials - ExplicitTangentBundle{1}(j[0], (jet_taylor_ev(Val{1}(), coeffs, j),)) -end - -function (j::Jet{T, N} where T)(x::ExplicitTangentBundle{N, M}) where {N, M} - error("TODO") -end diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index b5ef1f03..e8f04a18 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -1,5 +1,4 @@ partial(x::TangentBundle, i) = partial(getfield(x, :tangent), i) -partial(x::ExplicitTangent, i) = getfield(getfield(x, :partials), i) partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i) partial(x::UniformTangent, i) = getfield(x, :val) partial(x::ProductTangent, i) = ProductTangent(map(x->partial(x, i), getfield(x, :factors))) @@ -23,22 +22,13 @@ my_frule(::ZeroBundle{1, typeof(my_frule)}, args::ATB{1}...) = nothing (::∂☆{N})(::ZeroBundle{N, typeof(my_frule)}, ::ZeroBundle{N, ZeroBundle{1, typeof(my_frule)}}, args::ATB{N}...) where {N} = ZeroBundle{N}(nothing) shuffle_down(b::UniformBundle{N, B, U}) where {N, B, U} = - UniformBundle{minus1(N), <:Any, U}(UniformBundle{1, B, U}(b.primal, b.tangent.val), b.tangent.val) - -function shuffle_down(b::ExplicitTangentBundle{N, B}) where {N, B} - # N.B: This depends on the special properties of the canonical tangent index order - ExplicitTangentBundle{N-1}( - ExplicitTangentBundle{1}(b.primal, (partial(b, 1),)), - ntuple(2^(N-1)-1) do i - ExplicitTangentBundle{1}(partial(b, 2*i), (partial(b, 2*i+1),)) - end) -end + UniformBundle{N-1, <:Any, U}(UniformBundle{1, B, U}(b.primal, b.tangent.val), b.tangent.val) function shuffle_down(b::TaylorBundle{N, B}) where {N, B} TaylorBundle{N-1}( - ExplicitTangentBundle{1}(b.primal, (b.tangent.coeffs[1],)), + TaylorBundle{1}(b.primal, (b.tangent.coeffs[1],)), ntuple(N-1) do i - ExplicitTangentBundle{1}(b.tangent.coeffs[i], (b.tangent.coeffs[i+1],)) + TaylorBundle{1}(b.tangent.coeffs[i], (b.tangent.coeffs[i+1],)) end) end @@ -54,40 +44,16 @@ end function shuffle_up(r::CompositeBundle{1}) z₀ = primal(r.tup[1]) z₁ = partial(r.tup[1], 1) - z₂ = primal(r.tup[2]) z₁₂ = partial(r.tup[2], 1) - if z₁ == z₂ - return TaylorBundle{2}(z₀, (z₁, z₁₂)) - else - return ExplicitTangentBundle{2}(z₀, (z₁, z₂, z₁₂)) - end + return TaylorBundle{2}(z₀, (z₁, z₁₂)) end -function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N} - primal(b) === a[TaylorTangentIndex(1)] || return false - return all(1:(N-1)) do i - b[TaylorTangentIndex(i)] === a[TaylorTangentIndex(i+1)] - end -end - -# Check whether the tangent bundle element is taylor-like -isswifty(::TaylorBundle) = true -isswifty(::UniformBundle) = true -isswifty(b::CompositeBundle) = all(isswifty, b.tup) -isswifty(::Any) = false - function shuffle_up(r::CompositeBundle{N}) where {N} a, b = r.tup - if isswifty(a) && isswifty(b) && taylor_compatible(a, b) - return TaylorBundle{N+1}(primal(a), - ntuple(i->i == N+1 ? - b[TaylorTangentIndex(i-1)] : a[TaylorTangentIndex(i)], - N+1)) - else - return TangentBundle{N+1}(r.tup[1].primal, - (r.tup[1].tangent.partials..., primal(b), - ntuple(i->partial(b,i), 2^(N+1)-1)...)) - end + return TaylorBundle{N+1}(primal(a), + ntuple(i->i == N+1 ? + b[TaylorTangentIndex(i-1)] : a[TaylorTangentIndex(i)], + N+1)) end function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U} @@ -118,14 +84,14 @@ function (::∂☆internal{1})(args::AbstractTangentBundle{1}...) end function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...) - bundles = map((p,a) -> ExplicitTangentBundle{1}(a, (p,)), partials, args) + bundles = map((p,a) -> TaylorBundle{1}(a, (p,)), partials, args) result = ∂☆internal{1}()(bundles...) primal(result), first_partial(result) end function (::∂☆shuffle{N})(args::AbstractTangentBundle{N}...) where {N} - ∂☆p = ∂☆{minus1(N)}() - ∂☆p(ZeroBundle{minus1(N)}(my_frule), map(shuffle_down, args)...) + ∂☆p = ∂☆{N-1}() + ∂☆p(ZeroBundle{N-1}(my_frule), map(shuffle_down, args)...) end function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N} @@ -139,18 +105,6 @@ end (::∂☆{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆internal{N}()(args...) # Special case rules for performance -@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N} - s = primal(s) - ExplicitTangentBundle{N}(getfield(primal(x), s), - map(x->lifted_getfield(x, s), x.tangent.partials)) -end - -@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::ATB{N}, inbounds::ATB{N}) where {N} - s = primal(s) - ExplicitTangentBundle{N}(getfield(primal(x), s, primal(inbounds)), - map(x->lifted_getfield(x, s), x.tangent.partials)) -end - @Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}) where {N} s = primal(s) TaylorBundle{N}(getfield(primal(x), s), diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index 0e723135..cc1325a2 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -101,12 +101,12 @@ struct ∂⃖weaveInnerOdd{N, O}; b̄; end end @Base.constprop :aggressive function (w::∂⃖weaveInnerOdd{N, O})(Δ) where {N, O} @destruct c, c̄ = w.b̄(Δ...) - return (c̄, c), ∂⃖weaveInnerEven{plus1(N), O}() + return (c̄, c), ∂⃖weaveInnerEven{N+1, O}() end struct ∂⃖weaveInnerEven{N, O}; end @Base.constprop :aggressive function (w::∂⃖weaveInnerEven{N, O})(Δ′, x...) where {N, O} @destruct y, ȳ = Δ′(x...) - return y, ∂⃖weaveInnerOdd{plus1(N), O}(ȳ) + return y, ∂⃖weaveInnerOdd{N+1, O}(ȳ) end struct ∂⃖weaveOuterOdd{N, O}; end @@ -115,15 +115,15 @@ struct ∂⃖weaveOuterOdd{N, O}; end end @Base.constprop :aggressive function (w::∂⃖weaveOuterOdd{N, O})((Δ′′, Δ′′′)) where {N, O} @destruct α, ᾱ = Δ′′′(Δ′′) - return (NoTangent(), α...), ∂⃖weaveOuterEven{plus1(N), O}(ᾱ) + return (NoTangent(), α...), ∂⃖weaveOuterEven{N+1, O}(ᾱ) end struct ∂⃖weaveOuterEven{N, O}; ᾱ end @Base.constprop :aggressive function (w::∂⃖weaveOuterEven{N, O})(Δ⁴...) where {N, O} - return w.ᾱ(Base.tail(Δ⁴)...), ∂⃖weaveOuterOdd{plus1(N), O}() + return w.ᾱ(Base.tail(Δ⁴)...), ∂⃖weaveOuterOdd{N+1, O}() end function (::∂⃖{N})(::∂⃖{1}, args...) where {N} - @destruct (a, ā) = ∂⃖{plus1(N)}()(args...) + @destruct (a, ā) = ∂⃖{N+1}()(args...) let O = c_order(N) (a, Protected{N}(@opaque Δ->begin (b, b̄) = ā(Δ) @@ -188,10 +188,10 @@ end (::∂⃖rruleD{N, N})(Δ...) where {N} = error("Should not be reached") # ∂⃖rrule -@Base.pure term_depth(N) = 2^(N-2) +term_depth(N) = 1<<(N-2) function (::∂⃖rrule{N})(z, z̄) where {N} @destruct (y, ȳ) = z - y, ∂⃖rruleA{term_depth(N), 1}(∂⃖{minus1(N)}(), ȳ, z̄) + y, ∂⃖rruleA{term_depth(N), 1}(∂⃖{N-1}(), ȳ, z̄) end function (::∂⃖{N})(f::Core.IntrinsicFunction, args...) where {N} @@ -217,7 +217,7 @@ function (::∂⃖{N})(f::T, args...) where {T, N} end return z else - ∂⃖p = ∂⃖{minus1(N)}() + ∂⃖p = ∂⃖{N-1}() @destruct z, z̄ = ∂⃖p(rrule, f, args...) if z === nothing return ∂⃖recurse{N}()(f, args...) @@ -231,7 +231,7 @@ function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) wher Tuple{Any, Any}(∂⃖{1}()(f, args...)) end -@Base.pure function (::∂⃖{1})(::typeof(Core.apply_type), head, args...) +@Base.assume_effects :total function (::∂⃖{1})(::typeof(Core.apply_type), head, args...) return rrule(Core.apply_type, head, args...) end @@ -284,8 +284,8 @@ struct EvenOddEven{O, P, F, G}; f::F; g::G; end EvenOddEven{O, P}(f::F, g::G) where {O, P, F, G} = EvenOddEven{O, P, F, G}(f, g) struct EvenOddOdd{O, P, F, G}; f::F; g::G; end EvenOddOdd{O, P}(f::F, g::G) where {O, P, F, G} = EvenOddOdd{O, P, F, G}(f, g) -@Base.constprop :aggressive (o::EvenOddOdd{O, P, F, G})(Δ) where {O, P, F, G} = (o.f(Δ), EvenOddEven{plus1(O), P, F, G}(o.f, o.g)) -@Base.constprop :aggressive (e::EvenOddEven{O, P, F, G})(Δ...) where {O, P, F, G} = (e.g(Δ...), EvenOddOdd{plus1(O), P, F, G}(e.f, e.g)) +@Base.constprop :aggressive (o::EvenOddOdd{O, P, F, G})(Δ) where {O, P, F, G} = (o.f(Δ), EvenOddEven{O+1, P, F, G}(o.f, o.g)) +@Base.constprop :aggressive (e::EvenOddEven{O, P, F, G})(Δ...) where {O, P, F, G} = (e.g(Δ...), EvenOddOdd{O+1, P, F, G}(e.f, e.g)) @Base.constprop :aggressive (o::EvenOddOdd{O, O})(Δ) where {O} = o.f(Δ) @@ -363,11 +363,11 @@ struct ApplyOdd{O, P}; u; ∂⃖f; end struct ApplyEven{O, P}; u; ∂⃖∂⃖f; end @Base.constprop :aggressive function (a::ApplyOdd{O, P})(Δ) where {O, P} r, ∂⃖∂⃖f = a.∂⃖f(Δ) - (a.u(r), ApplyEven{plus1(O), P}(a.u, ∂⃖∂⃖f)) + (a.u(r), ApplyEven{O+1, P}(a.u, ∂⃖∂⃖f)) end @Base.constprop :aggressive function (a::ApplyEven{O, P})(_, _, ff, args...) where {O, P} r, ∂⃖∂⃖∂⃖f = Core._apply_iterate(iterate, a.∂⃖∂⃖f, (ff,), args...) - (r, ApplyOdd{plus1(O), P}(a.u, ∂⃖∂⃖∂⃖f)) + (r, ApplyOdd{O+1, P}(a.u, ∂⃖∂⃖∂⃖f)) end @Base.constprop :aggressive function (a::ApplyOdd{O, O})(Δ) where {O} r = a.∂⃖f(Δ) @@ -381,10 +381,10 @@ function (this::∂⃖{N})(::typeof(Core._apply_iterate), iterate, f, args::Unio end -@Base.pure c_order(N::Int) = 2^N - 1 +c_order(N::Int) = 1<= 2 && print(io, " + ", x.partials[2], " ∂₂") - length(x.partials) >= 3 && print(io, " + ", x.partials[3], " ∂₁ ∂₂") - length(x.partials) >= 4 && print(io, " + ", x.partials[4], " ∂₃") - length(x.partials) >= 5 && print(io, " + ", x.partials[5], " ∂₁ ∂₃") - length(x.partials) >= 6 && print(io, " + ", x.partials[6], " ∂₂ ∂₃") - length(x.partials) >= 7 && print(io, " + ", x.partials[7], " ∂₁ ∂₂ ∂₃") -end - -function Base.getindex(a::ExplicitTangentBundle{N}, b::TaylorTangentIndex) where {N} - if b.i === N - return a.tangent.partials[end] - end - error("$(typeof(a)) is not taylor-like. Taylor indexing is ambiguous") -end - const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}} function TaylorBundle{N, B}(primal::B, coeffs) where {N, B} check_taylor_invariants(coeffs, primal, N) - _TangentBundle(Val{N}(), primal, TaylorTangent(coeffs)) + TangentBundle{N}(primal, TaylorTangent(coeffs)) end function check_taylor_invariants(coeffs, primal, N) @@ -215,7 +160,7 @@ end @ChainRulesCore.non_differentiable check_taylor_invariants(coeffs, primal, N) function TaylorBundle{N}(primal, coeffs) where {N} - _TangentBundle(Val{N}(), primal, TaylorTangent(coeffs)) + TangentBundle{N}(primal, TaylorTangent(coeffs)) end function Base.show(io::IO, x::TaylorBundle{1}) @@ -230,25 +175,13 @@ function Base.getindex(tb::TaylorBundle, tti::CanonicalTangentIndex) tb.tangent.coeffs[count_ones(tti.i)] end -function truncate(tt::TaylorTangent, order::Val{N}) where {N} - TaylorTangent(tt.coeffs[1:N]) -end - -function truncate(ut::UniformTangent, order::Val) - ut -end - -function truncate(tb::TangentBundle, order::Val) - _TangentBundle(order, tb.primal, truncate(tb.tangent, order)) -end - const UniformBundle{N, B, U} = TangentBundle{N, B, UniformTangent{U}} -UniformBundle{N, B, U}(primal::B, partial::U) where {N,B,U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(partial)) -UniformBundle{N, B, U}(primal::B) where {N,B,U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(U.instance)) -UniformBundle{N, B}(primal::B, partial::U) where {N,B,U} = _TangentBundle(Val{N}(),primal, UniformTangent{U}(partial)) -UniformBundle{N}(primal, partial::U) where {N,U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(partial)) -UniformBundle{N, <:Any, U}(primal, partial::U) where {N, U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(U.instance)) -UniformBundle{N, <:Any, U}(primal) where {N, U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(U.instance)) +UniformBundle{N, B, U}(primal::B, partial::U) where {N,B,U} = TangentBundle{N}(primal, UniformTangent{U}(partial)) +UniformBundle{N, B, U}(primal::B) where {N,B,U} = TangentBundle{N}(primal, UniformTangent{U}(U.instance)) +UniformBundle{N, B}(primal::B, partial::U) where {N,B,U} = TangentBundle{N}(primal, UniformTangent{U}(partial)) +UniformBundle{N}(primal, partial::U) where {N,U} = TangentBundle{N}(primal, UniformTangent{U}(partial)) +UniformBundle{N, <:Any, U}(primal, partial::U) where {N, U} = TangentBundle{N}(primal, UniformTangent{U}(U.instance)) +UniformBundle{N, <:Any, U}(primal) where {N, U} = TangentBundle{N}(primal, UniformTangent{U}(U.instance)) const ZeroBundle{N, B} = UniformBundle{N, B, ZeroTangent} const DNEBundle{N, B} = UniformBundle{N, B, NoTangent} @@ -288,24 +221,6 @@ end expand_singleton_to_array(asize, a::AbstractZero) = fill(a, asize...) expand_singleton_to_array(asize, a::AbstractArray) = a -function unbundle(atb::ExplicitTangentBundle{Order, A}) where {Order, Dim, T, A<:AbstractArray{T, Dim}} - asize = size(atb.primal) - StructArray{ExplicitTangentBundle{Order, T}}((atb.primal, map(a->expand_singleton_to_array(asize, a), atb.tangent.partials)...)) -end - -function StructArrays.staticschema(::Type{<:ExplicitTangentBundle{N, B, T}}) where {N, B, T} - Tuple{B, T.parameters...} -end - -function StructArrays.component(m::ExplicitTangentBundle{N, B, T}, i::Int) where {N, B, T} - i == 1 && return m.primal - return m.tangent.partials[i - 1] -end - -function StructArrays.createinstance(T::Type{<:ExplicitTangentBundle}, args...) - T(first(args), Base.tail(args)) -end - function unbundle(atb::TaylorBundle{Order, A}) where {Order, Dim, T, A<:AbstractArray{T, Dim}} StructArray{TaylorBundle{Order, T}}((atb.primal, atb.tangent.coeffs...)) end @@ -343,14 +258,6 @@ function StructArrays.createinstance(T::Type{<:ZeroBundle}, args...) T(args[1], args[2]) end -function rebundle(A::AbstractArray{<:ExplicitTangentBundle{N}}) where {N} - ExplicitTangentBundle{N}( - map(x->x.primal, A), - ntuple(2^N-1) do i - map(x->x.tangent.partials[i], A) - end) -end - function rebundle(A::AbstractArray{<:TaylorBundle{N}}) where {N} TaylorBundle{N}( map(x->x.primal, A), diff --git a/test/runtests.jl b/test/runtests.jl index 7c492d73..419bf55e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -50,7 +50,7 @@ ChainRules.rrule(::typeof(my_tuple), args...) = args, Δ->Core.tuple(NoTangent() # Minimal 2-nd order forward smoke test @test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin), - Diffractor.ExplicitTangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0) + Diffractor.TaylorBundle{2}(1.0, (1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0) function simple_control_flow(b, x) if b diff --git a/test/stage2_fwd.jl b/test/stage2_fwd.jl index 06f21bd0..bcfc4f8c 100644 --- a/test/stage2_fwd.jl +++ b/test/stage2_fwd.jl @@ -14,7 +14,6 @@ module stage2_fwd self_minus(a) = myminus(a, a) let self_minus′′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64}, 2) - # TODO: The IR for this currently contains Union{Diffractor.TangentBundle{2, Float64, Diffractor.ExplicitTangent{Tuple{Float64, Float64, Float64}}}, Diffractor.TangentBundle{2, Float64, Diffractor.TaylorTangent{Tuple{Float64, Float64}}}} # We should have Diffractor be able to prove uniformity @test_broken isa(self_minus′′, Core.OpaqueClosure{Tuple{Float64}, Float64}) @test self_minus′′(1.0) == 0.